change / sglang · Commits · c38b5fb4

Commit c38b5fb4 (unverified), authored Jan 30, 2025 by Yineng Zhang, committed via GitHub on Jan 30, 2025

    update 3rdparty and rms norm for sgl-kernel (#3213)

Parent: 20453cef

Showing 5 changed files with 8 additions and 113 deletions (+8 -113):
  sgl-kernel/3rdparty/cutlass                                    +1   -1
  sgl-kernel/3rdparty/flashinfer                                 +1   -1
  sgl-kernel/pyproject.toml                                      +1   -1
  sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu    +4   -109
  sgl-kernel/version.py                                          +1   -1
sgl-kernel/3rdparty/cutlass @ bdd64179 (compare: b78588d1 ... bdd64179)

-Subproject commit b78588d1630aa6643bf021613717bafb705df4ef
+Subproject commit bdd641790ad49353b40ada41330552a78d2f8b5a

sgl-kernel/3rdparty/flashinfer @ e5a3befb (compare: 4f1f0898 ... e5a3befb)

-Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5
+Subproject commit e5a3befbe3e63025f0158bc96b218a9c5f402ac7
sgl-kernel/pyproject.toml

...
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sgl-kernel"
-version = "0.0.3"
+version = "0.0.3.post1"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.9"
...
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu

-// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh
-// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu
-// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0
 #include <ATen/cuda/CUDAContext.h>

-#include <flashinfer/math.cuh>
+#include <flashinfer/norm.cuh>
-#include <flashinfer/utils.cuh>
-#include <flashinfer/vec_dtypes.cuh>
-#include <numeric>

 #include "utils.h"

 using namespace flashinfer;
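
For reference, the vendored kernel deleted below fused the residual add into RMS normalization. Reading the code, each CUDA block handles one row of length d: it first accumulates the sum of squares of x_j = input_j + residual_j (writing x back into `residual`), then rescales by the reciprocal RMS. In LaTeX, with w the `weight` vector and flashinfer's convention of adding eps after dividing the sum of squares by d:

$$
r_j \leftarrow x_j + r_j, \qquad
y_j = \frac{r_j \, w_j}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} r_k^2 + \varepsilon}},
$$

where the output y overwrites `input` in place.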
-template <uint32_t VEC_SIZE, typename T>
-__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                      T* __restrict__ weight, const uint32_t d, float eps) {
-  const uint32_t bx = blockIdx.x;
-  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  constexpr uint32_t warp_size = 32;
-  const uint32_t num_warps = blockDim.y;
-  const uint32_t thread_id = tx + ty * warp_size;
-  const uint32_t num_threads = num_warps * warp_size;
-  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-  extern __shared__ float smem[];
-
-  float sum_sq = 0.f;
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0.f);
-    vec_t<T, VEC_SIZE> residual_vec;
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = float(input_vec[j]);
-      x += float(residual_vec[j]);
-      sum_sq += x * x;
-      residual_vec[j] = (T)x;
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-
-  // first, warp reduce sum
-#pragma unroll
-  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-    sum_sq += math::shfl_xor_sync(sum_sq, offset);
-  }
-
-  smem[ty] = sum_sq;
-  __syncthreads();
-  // then, cross warp reduce sum using only the first warp
-  if (ty == 0) {
-    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
-#pragma unroll
-    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-      sum_sq += math::shfl_xor_sync(sum_sq, offset);
-    }
-    smem[0] = sum_sq;
-  }
-  __syncthreads();
-
-  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> residual_vec;
-    input_vec.fill(0.f);
-    weight_vec.fill(0.f);
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-}
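
The deletion above removes a textbook two-stage block reduction: a butterfly reduction over each warp's 32 lanes via `math::shfl_xor_sync` (assumed here to be flashinfer's thin wrapper over the raw `__shfl_xor_sync` intrinsic), per-warp partials parked in shared memory, then a second butterfly in warp 0. Below is a minimal self-contained CUDA sketch of the same pattern; the kernel name, sizes, and `main` driver are illustrative, not sgl-kernel code:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void BlockSumKernel(const float* __restrict__ in, float* __restrict__ out, int n) {
  extern __shared__ float smem[];
  const unsigned tx = threadIdx.x, ty = threadIdx.y;  // 32 lanes x blockDim.y warps
  const unsigned tid = tx + ty * 32;
  float v = (tid < n) ? in[blockIdx.x * n + tid] : 0.f;

  // stage 1: butterfly reduction within each 32-lane warp
  for (unsigned offset = 16; offset > 0; offset /= 2)
    v += __shfl_xor_sync(0xffffffff, v, offset);

  if (tx == 0) smem[ty] = v;  // one partial sum per warp
  __syncthreads();

  // stage 2: warp 0 reduces the per-warp partials
  if (ty == 0) {
    v = (tx < blockDim.y) ? smem[tx] : 0.f;
    for (unsigned offset = 16; offset > 0; offset /= 2)
      v += __shfl_xor_sync(0xffffffff, v, offset);
    if (tx == 0) out[blockIdx.x] = v;
  }
}

int main() {
  const int n = 256;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.f;
  // one block of 32 x 8 threads, one float of dynamic smem per warp
  BlockSumKernel<<<1, dim3(32, n / 32), (n / 32) * sizeof(float)>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("sum = %f (expect %d)\n", *out, n);
  cudaFree(in);
  cudaFree(out);
  return 0;
}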
-
-template <typename T>
-cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
-                            float eps = 1e-5, cudaStream_t stream = 0) {
-  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
-  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
-  const uint32_t num_warps = ceil_div(block_size, 32);
-  dim3 nblks(batch_size);
-  dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &d, &eps};
-
-  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
-    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
-  });
-
-  return cudaSuccess;
-}
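
The deleted launcher sizes everything from the hidden dimension: 16-byte (128-bit) vectorized accesses, at most 1024 threads per block, one block per batch row, and one float of shared memory per warp. A host-only sketch that works through that arithmetic for an assumed half-precision, d = 4096 case (`ceil_div` is reimplemented locally for self-containment; the values in the comments are the point):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>  // std::gcd (C++17)

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  const uint32_t sizeof_T = 2;   // sizeof(half)
  const uint32_t d = 4096;       // assumed hidden size
  // 16 bytes per access -> 8 half elements per vectorized load/store
  const uint32_t vec_size = std::gcd(16 / sizeof_T, d);                // 8
  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);  // 512
  const uint32_t num_warps = ceil_div(block_size, 32);                 // 16
  const uint32_t smem_size = num_warps * sizeof(float);                // 64 bytes
  printf("vec_size=%u block_size=%u num_warps=%u smem=%uB\n",
         vec_size, block_size, num_warps, smem_size);
  return 0;
}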
 void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
                            double eps) {
   CHECK_INPUT(input);
   CHECK_INPUT(residual);
...

@@ -130,9 +25,9 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
   // support float16, bfloat16 and float32
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status =
-        FusedAddRMSNorm(
+        norm::FusedAddRMSNorm(
             static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
             static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps,
             torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
     return true;
...
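
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16 is sgl-kernel's own macro, presumably defined in the utils.h included above; its definition is not part of this diff. For readers unfamiliar with the idiom, ATen's stock AT_DISPATCH_FLOATING_TYPES_AND2 expresses the same runtime-dtype-to-C++-type dispatch (it additionally admits float64, so it only approximates the float32/float16/bfloat16 set named in the comment). A hypothetical sketch, not sgl-kernel's actual code:

#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Illustrative only: shows the dispatch idiom used by the binding above.
void dispatch_example(torch::Tensor input) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(),
      "sgl_fused_add_rmsnorm", [&] {
        // Inside the lambda, scalar_t names the C++ type matching the runtime
        // dtype -- the role c_type plays in the diff above.
        scalar_t* ptr = static_cast<scalar_t*>(input.data_ptr());
        (void)ptr;  // a real binding would launch its kernel here
      });
}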
sgl-kernel/version.py

-__version__ = "0.0.3"
+__version__ = "0.0.3.post1"