sglang commit 04f0b4cb (unverified)
Authored Jan 24, 2025 by Yineng Zhang; committed by GitHub on Jan 24, 2025

minor: update sgl-kernel setup (#3107)

Parent: 4505a436
Showing 2 changed files with 103 additions and 15 deletions (+103, -15)
sgl-kernel/setup.py  (+11, -15)
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu  (+92, -0)
sgl-kernel/setup.py
@@ -38,10 +38,10 @@ def _get_version():
                 return line.split("=")[1].strip().strip('"')
 
 
-cutlass = root / "3rdparty" / "cutlass"
 cutlass_default = root / "3rdparty" / "cutlass"
 cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
 flashinfer = root / "3rdparty" / "flashinfer"
+turbomind = root / "3rdparty" / "turbomind"
 include_dirs = [
     cutlass.resolve() / "include",
     cutlass.resolve() / "tools" / "util" / "include",
@@ -49,6 +49,8 @@ include_dirs = [
     flashinfer.resolve() / "include",
     flashinfer.resolve() / "include" / "gemm",
     flashinfer.resolve() / "csrc",
+    turbomind.resolve(),
+    turbomind.resolve() / "src",
 ]
 nvcc_flags = [
     "-DNDEBUG",
@@ -63,6 +65,11 @@ nvcc_flags = [
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
 ]
+nvcc_flags_fp8 = [
+    "-DFLASHINFER_ENABLE_FP8",
+    "-DFLASHINFER_ENABLE_FP8_E4M3",
+    "-DFLASHINFER_ENABLE_FP8_E5M2",
+]
 
 sources = [
     "src/sgl-kernel/csrc/trt_reduce_internal.cu",
@@ -73,6 +80,7 @@ sources = [
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
     "src/sgl-kernel/csrc/rotary_embedding.cu",
+    "src/sgl-kernel/csrc/fused_add_rms_norm.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/group_gemm.cu",
@@ -92,13 +100,7 @@ if torch.cuda.is_available():
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
         sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
     if sm_version >= 90:
-        nvcc_flags.extend(
-            [
-                "-DFLASHINFER_ENABLE_FP8",
-                "-DFLASHINFER_ENABLE_FP8_E4M3",
-                "-DFLASHINFER_ENABLE_FP8_E5M2",
-            ]
-        )
+        nvcc_flags.extend(nvcc_flags_fp8)
     if sm_version >= 80:
         nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
 else:
@@ -107,13 +109,7 @@ else:
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
         sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
     if enable_fp8:
-        nvcc_flags.extend(
-            [
-                "-DFLASHINFER_ENABLE_FP8",
-                "-DFLASHINFER_ENABLE_FP8_E4M3",
-                "-DFLASHINFER_ENABLE_FP8_E5M2",
-            ]
-        )
+        nvcc_flags.extend(nvcc_flags_fp8)
     if enable_bf16:
         nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
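With this change the FP8 defines live in a single nvcc_flags_fp8 list and are appended only for sm_90 builds (or when FP8 is explicitly enabled). The sketch below is illustrative only and not code from this repository: it shows the standard preprocessor pattern by which a -DFLASHINFER_ENABLE_FP8_E4M3 style define gates FP8-only code at compile time, which is why the flags are withheld on architectures without FP8 support.

// Illustrative sketch only: how a -DFLASHINFER_ENABLE_FP8_E4M3 style define gates
// FP8-specific code at compile time. Build with or without the flag to compare, e.g.
//   nvcc -DFLASHINFER_ENABLE_FP8_E4M3 fp8_gate_demo.cu -o fp8_gate_demo
#include <cstdio>

#ifdef FLASHINFER_ENABLE_FP8_E4M3
#include <cuda_fp8.h>  // __nv_fp8_e4m3 types; only pulled in when the flag is set

__global__ void cast_to_fp8(const float* in, __nv_fp8_e4m3* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = __nv_fp8_e4m3(in[i]);  // narrow fp32 -> e4m3
  }
}
#endif

int main() {
#ifdef FLASHINFER_ENABLE_FP8_E4M3
  std::printf("compiled with the FP8 E4M3 path\n");
#else
  std::printf("compiled without the FP8 path\n");
#endif
  return 0;
}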
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu  (new file, mode 100644)
// Adapted from
// https://github.com/InternLM/lmdeploy/blob/800b6010c0bf76aadf678bc38a507b749fb9774c/src/turbomind/kernels/norm/rms_norm.cu

#include <turbomind/kernels/core/array_ops.h>
#include <turbomind/kernels/core/common.h>

#include <cub/block/block_reduce.cuh>

using namespace turbomind;

template <class T, class Tacc, int block_dim, int vec_size>
__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual,
                                          T* __restrict__ hidden_states,
                                          const T* __restrict__ weights,
                                          const T* __restrict__ bias,
                                          int dims,
                                          int num,
                                          float eps,
                                          float inv_dims) {
  const int ti = blockIdx.x;
  const int di = threadIdx.x * vec_size;

  if (ti >= num) {
    return;
  }

  residual += dims * ti;
  hidden_states += dims * ti;

  Array<Tacc, vec_size> accum{};

  Array<T, vec_size> r_vec;
  Array<T, vec_size> h_vec;
  Array<T, vec_size> b_vec;

  for (int i = di; i < dims; i += block_dim * vec_size) {
    Load(r_vec, &residual[i]);
    Load(h_vec, &hidden_states[i]);

    using namespace ops;
    r_vec = r_vec + h_vec;

    if (bias) {
      Ldg(b_vec, &bias[i]);
      r_vec = r_vec + b_vec;
    }

    Store(&residual[i], r_vec);

    Array<Tacc, vec_size> tmp = cast<Tacc>(r_vec);
    accum = accum + tmp * tmp;
  }

  float sum{};
  PRAGMA_UNROLL
  for (int i = 0; i < vec_size; ++i) {
    sum += accum[i];
  }

  using BlockReduce = cub::BlockReduce<Tacc, block_dim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  sum = BlockReduce{temp_storage}.Sum(sum);

  __shared__ float shared_sum;

  if (threadIdx.x == 0) {
    shared_sum = rsqrtf(sum * inv_dims + eps);
  }

  __syncthreads();

  sum = shared_sum;

  Array<T, vec_size> w_vec;
  for (int i = di; i < dims; i += block_dim * vec_size) {
    Load(r_vec, &residual[i]);
    Ldg(w_vec, &weights[i]);
    PRAGMA_UNROLL
    for (int c = 0; c < vec_size; ++c) {
      r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c];
    }
    Store(&hidden_states[i], r_vec);
  }
}

template <class T>
void invokeBiasResidualRMSNorm(
    T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st) {
  constexpr int vec_size = 16 / sizeof(T);
  constexpr int threads = 512;

  const int blocks = num;

  BiasResidualRMSNormKernel<T, float, threads, vec_size>
      <<<blocks, threads, 0, st>>>(residual, hidden_states, weights, bias, dims, num, eps, 1.f / dims);
}
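The new kernel fuses the residual update with RMSNorm: per token, residual is overwritten with residual + hidden_states (+ bias when given), and hidden_states is overwritten with residual * rsqrt(mean(residual^2) + eps) * weights. A minimal host-side sketch of driving invokeBiasResidualRMSNorm follows; it assumes fp16 data, a contiguous [num_tokens, hidden_dim] layout with hidden_dim a multiple of the 16-byte vector width, and that it sits in the same translation unit as the code above so the template can be instantiated. The function and pointer names are hypothetical and not part of this commit.

// Hypothetical usage sketch (not part of this commit): launch the fused add + RMSNorm
// kernel for fp16 activations. Must live in the same .cu as the definitions above so
// that invokeBiasResidualRMSNorm<half> gets instantiated.
#include <cuda_fp16.h>

void fused_add_rms_norm_fp16_example(half* d_residual,      // [num_tokens, hidden_dim], updated in place with the residual sum
                                     half* d_hidden,        // [num_tokens, hidden_dim], overwritten with the normalized output
                                     const half* d_weight,  // [hidden_dim] RMSNorm scale
                                     int num_tokens,
                                     int hidden_dim,        // multiple of 8 for fp16 (16-byte vectorized loads/stores)
                                     cudaStream_t stream) {
  // Passing bias == nullptr skips the optional bias add inside the kernel.
  invokeBiasResidualRMSNorm<half>(d_residual, d_hidden, d_weight, /*bias=*/nullptr,
                                  hidden_dim, num_tokens, /*eps=*/1e-6f, stream);
}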