sglang · Commit 827aa873 (Unverified)

cleanup sgl-kernel kernels (#3175)

Authored Jan 27, 2025 by Yineng Zhang; committed via GitHub on Jan 27, 2025. Parent: f8ca66fb

Showing 8 changed files with 147 additions and 26 deletions (+147 −26):
.github/workflows/pr-test.yml                                 +1   −0
sgl-kernel/setup.py                                           +1   −1
sgl-kernel/src/sgl-kernel/__init__.py                         +0   −2
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu   +140 −0
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h           +1   −5
sgl-kernel/src/sgl-kernel/ops/__init__.py                     +1   −9
sgl-kernel/src/sgl-kernel/torch_extension.cc                  +2   −8
sgl-kernel/tests/test_norm.py                                 +1   −1
.github/workflows/pr-test.yml

```diff
@@ -51,6 +51,7 @@ jobs:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false
     runs-on: 1-gpu-runner
     strategy:
       fail-fast: false
       matrix:
         range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100]
     steps:
```
sgl-kernel/setup.py

```diff
@@ -88,7 +88,7 @@ sources = [
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
-    "src/sgl-kernel/csrc/rotary_embedding.cu",
+    "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
```
sgl-kernel/src/sgl-kernel/__init__.py

```diff
@@ -17,7 +17,6 @@ from sgl_kernel.ops import (
     moe_align_block_size,
     register_graph_buffers,
-    rmsnorm,
     rotary_embedding,
     sampling_scaling_penalties,
     silu_and_mul,
     top_k_renorm_prob,
@@ -44,7 +43,6 @@ __all__ = [
     "moe_align_block_size",
     "register_graph_buffers",
-    "rmsnorm",
     "rotary_embedding",
     "sampling_scaling_penalties",
     "silu_and_mul",
     "top_k_renorm_prob",
```
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu (new file, 0 → 100644)

```cuda
// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh
// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu
// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0
#include <ATen/cuda/CUDAContext.h>

#include <flashinfer/math.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include <numeric>

#include "utils.h"

using namespace flashinfer;

template <uint32_t VEC_SIZE, typename T>
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight,
                                      const uint32_t d, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
  const uint32_t num_warps = blockDim.y;
  const uint32_t thread_id = tx + ty * warp_size;
  const uint32_t num_threads = num_warps * warp_size;
  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
  extern __shared__ float smem[];

  float sum_sq = 0.f;

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    input_vec.fill(0.f);
    vec_t<T, VEC_SIZE> residual_vec;
    residual_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      float x = float(input_vec[j]);
      x += float(residual_vec[j]);
      sum_sq += x * x;
      residual_vec[j] = (T)x;
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }

  // first, warp reduce sum
#pragma unroll
  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
    sum_sq += math::shfl_xor_sync(sum_sq, offset);
  }

  smem[ty] = sum_sq;
  __syncthreads();
  // then, cross warp reduce sum using only the first warp
  if (ty == 0) {
    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
      sum_sq += math::shfl_xor_sync(sum_sq, offset);
    }
    smem[0] = sum_sq;
  }
  __syncthreads();

  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    vec_t<T, VEC_SIZE> weight_vec;
    vec_t<T, VEC_SIZE> residual_vec;
    input_vec.fill(0.f);
    weight_vec.fill(0.f);
    residual_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }
}

template <typename T>
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5,
                            cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
  void* args[] = {&input, &residual, &weight, &d, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });

  return cudaSuccess;
}

void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
  CHECK_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  auto device = input.device();
  CHECK_EQ(residual.device(), device);
  CHECK_EQ(weight.device(), device);
  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
  CHECK_DIM(1, weight);    // weight: (hidden_size)
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status =
        FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
                        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
    TORCH_CHECK(status == cudaSuccess,
                "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
```
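For orientation on the launch shape: the launcher assigns one thread block per row, so for float16 with d = 4096 it picks vec_size = gcd(16/2, 4096) = 8, block_size = min(1024, 4096/8) = 512, i.e. 16 warps of 32 threads, with num_warps floats of shared memory for the cross-warp reduction.

For readers tracing the kernel's semantics, here is a minimal PyTorch reference of the same two-pass computation (an illustrative sketch, not part of the commit; `fused_add_rmsnorm_ref` is a hypothetical name):

```python
import torch

def fused_add_rmsnorm_ref(input: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float = 1e-6) -> None:
    # Pass 1 of the kernel: x = input + residual, accumulated in fp32,
    # then written back into `residual` in the storage dtype.
    x = input.float() + residual.float()
    residual.copy_(x.to(residual.dtype))
    # Block-wide reduction: rms_rcp = rsqrt(mean(x^2) + eps), one scalar per row.
    rms_rcp = torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + eps)
    # Pass 2: the normalized, weighted result overwrites `input`,
    # re-reading the rounded value now stored in `residual`.
    input.copy_((residual.float() * rms_rcp * weight.float()).to(input.dtype))
```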
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h

```diff
@@ -50,15 +50,11 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
                                 const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                 torch::Tensor new_kv);
 
 // rotary embedding
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
 
 // rms norm
 void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
 
 // fused rms norm
-void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
 
 // gemma rms norm
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
```
sgl-kernel/src/sgl-kernel/ops/__init__.py

```diff
@@ -142,12 +142,6 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
     )
 
 
 def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
     return torch.ops.sgl_kernels.rotary_embedding(
         positions, query, key, head_size, cos_sin_cache, is_neox
     )
 
 
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
 # Kudos to @yzh119
 def rmsnorm(
@@ -167,9 +161,7 @@ def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    with input.device as device:
-        torch.ops.sgl_kernels.fused_add_rmsnorm(
-            input, residual, weight, eps, _get_cuda_stream(device)
-        )
+    torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
 
 
 def gemma_rmsnorm(
```
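With the `with input.device` block and the explicit `_get_cuda_stream` plumbing gone, the Python wrapper is a plain pass-through; the C++ entry point fetches the current CUDA stream itself via at::cuda::getCurrentCUDAStream(). A usage sketch (assuming the installed sgl_kernel package re-exports fused_add_rmsnorm, as the top-level __init__.py context above suggests):

```python
import torch
from sgl_kernel import fused_add_rmsnorm  # assumed top-level re-export

hidden = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(hidden)
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

# In place: `residual` becomes hidden + residual,
# and `hidden` becomes RMSNorm(hidden + residual) * weight.
fused_add_rmsnorm(hidden, residual, weight, eps=1e-6)
```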
sgl-kernel/src/sgl-kernel/torch_extension.cc

```diff
@@ -45,19 +45,13 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
         "new_kv) -> ()");
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
 
   // rotary embedding
   m.def(
       "rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool "
       "is_neox) -> ()");
   m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
 
   // rms norm
   m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
   m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
 
   // fused rms norm
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm);
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
 
   // gemma rms norm
   m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
```
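In the schema string, `Tensor!` marks arguments the op mutates in place, so both input and residual are written while weight stays read-only. Since the trailing `int cuda_stream` is dropped, a raw dispatcher call against the new registration looks like this (a fragment sketch; the tensor names are placeholders):

```python
# No stream argument anymore; eps is the schema's float.
torch.ops.sgl_kernels.fused_add_rmsnorm(hidden, residual, weight, 1e-6)
```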
sgl-kernel/tests/test_norm.py

```diff
@@ -69,7 +69,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
-@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     eps = 1e-6
```
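The added float32 case matters because the new kernel file dispatches over float16, bfloat16, and float32. The real test body is truncated above; a self-contained sketch of what such a parametrized check can look like, using the in-place semantics documented in the kernel (requires a CUDA build of sgl_kernel):

```python
import pytest
import torch
from sgl_kernel import fused_add_rmsnorm  # assumed top-level re-export

@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_fused_add_rmsnorm_matches_reference(dtype):
    torch.manual_seed(0)
    eps = 1e-6
    x = torch.randn(19, 1024, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(1024, dtype=dtype, device="cuda")
    x_ref, res_ref = x.clone(), residual.clone()

    fused_add_rmsnorm(x, residual, weight, eps)

    # fp32 reference: residual update followed by weighted RMSNorm.
    added = x_ref.float() + res_ref.float()
    expected = added * torch.rsqrt(added.square().mean(-1, keepdim=True) + eps) * weight.float()
    torch.testing.assert_close(residual.float(), added, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(x.float(), expected, rtol=2e-2, atol=2e-2)
```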