sglang / Commits / c4e314f9

Unverified commit c4e314f9, authored Sep 25, 2025 by Xiaoyu Zhang, committed by GitHub Sep 25, 2025

Restruct sgl-kernel benchmark (#10861)
parent 7a06ef98

Showing 6 changed files with 342 additions and 15 deletions
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py    +1   -1
sgl-kernel/benchmark/bench_per_token_quant_fp8.py     +1   -1
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py        +4   -6
sgl-kernel/benchmark/bench_rmsnorm.py                 +318 -0
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py    +3   -1
sgl-kernel/tests/test_norm.py                         +15  -6
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py

@@ -88,7 +88,7 @@ def benchmark(batch_size, seq_len, provider):
     elif provider == "sglang":
         fn = lambda: sglang_scaled_fp8_quant(x.clone())

-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
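The change here, repeated across the benchmark files below, swaps eager timing for CUDA-graph timing. A minimal sketch of the difference, assuming only torch and triton are installed (the tensor and kernel are stand-ins, not from the commit):

import torch
import triton.testing

x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
fn = lambda: torch.nn.functional.relu(x)

# Eager timing: every iteration pays kernel-launch overhead, which can
# dominate for small, fast kernels like per-tensor quantization.
ms_eager = triton.testing.do_bench(fn)

# CUDA-graph timing: the launches are captured once into a graph and
# replayed, so the measurement is closer to pure device time. The callable
# must be capture-safe (fixed shapes, no data-dependent host-side work).
ms_graph = triton.testing.do_bench_cudagraph(fn)

print(f"eager {ms_eager:.4f} ms vs cudagraph {ms_graph:.4f} ms")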
sgl-kernel/benchmark/bench_per_token_quant_fp8.py

@@ -160,7 +160,7 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
     elif provider == "sglang":
         fn = lambda: sglang_per_token_quant_fp8(x.clone())

-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
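One detail shared by all of these benchmarks: with quantiles=[0.5, 0.2, 0.8], do_bench and do_bench_cudagraph return the median, the 0.2-quantile, and the 0.8-quantile runtimes in that order, and the return statement converts milliseconds to microseconds. A small illustration of the convention (the tensor is a stand-in):

import torch
import triton.testing

x = torch.randn(1024, 1024, device="cuda")
fn = lambda: x.mul(2)

quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
# ms = median, min_ms = p20, max_ms = p80; reported upstream in µs as (median, max, min)
us_median, us_max, us_min = 1000 * ms, 1000 * max_ms, 1000 * min_ms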
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py

@@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K):
     quantiles = [0.5, 0.2, 0.8]
     if provider == "FP16":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: torch.matmul(a_fp16, b_fp16),
             quantiles=quantiles,
         )
     if provider == "W8A8":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
     if provider == "Qserve_W4A8_Per_Channel":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: qserve_w4a8_per_chn_gemm(
                 a_qserve_chn,
                 b_qserve_chn,

@@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K):
             quantiles=quantiles,
         )
     if provider == "Qserve_W4A8_Per_Group":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: qserve_w4a8_per_group_gemm(
                 a_qserve_group,
                 b_qserve_group,

@@ -189,8 +189,6 @@ if __name__ == "__main__":
         print(f"{model_name} N={N} K={K}: ")
         benchmark.run(
             print_data=True,
-            show_plots=True,
             save_path="bench_qserve_w4a8_gemm_res",
             N=N,
             K=K,
         )
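Note that every lambda handed to do_bench_cudagraph closes over tensors allocated once, outside the timed callable (a_fp16, b_fp16, the quantized operands). That is not just style: CUDA-graph replay reuses the captured memory addresses, so inputs must be pre-allocated and stable across iterations. A self-contained sketch of the pattern, with torch.matmul standing in for the GEMM kernels:

import torch
import triton.testing

# allocate once; the graph will replay against these exact buffers
a_fp16 = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
b_fp16 = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
    lambda: torch.matmul(a_fp16, b_fp16),  # closure over stable tensors
    quantiles=[0.5, 0.2, 0.8],
)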
benchmark/kernels/rmsnorm/benchmark_rmsnorm.py → sgl-kernel/benchmark/bench_rmsnorm.py
 # Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across
 # (batch_size, seq_len, hidden_size) and prints speed-up.
+import argparse
 import itertools
-from typing import Optional, Tuple, Union
+import re
+from typing import List, Optional, Tuple, Union

+import sgl_kernel
 import torch
+import torch.nn as nn
 import triton
+import triton.testing
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
-from torch import nn
+from sgl_kernel.utils import is_arch_support_pdl
 from vllm import _custom_ops as vllm_ops


+def str2int_list(arg: str) -> List[int]:
+    if arg in ("", None):
+        return []
+    if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
+        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
+    return [int(x) for x in arg.split(",")]


 class HuggingFaceRMSNorm(nn.Module):
     def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
         super().__init__()
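The helper rejects anything but comma-separated digits, so CLI overrides such as --batch_sizes 1,4,16 either parse cleanly or fail fast. A quick check of its behavior:

assert str2int_list("1,4,16") == [1, 4, 16]
assert str2int_list("") == []
# str2int_list("1,,4") raises argparse.ArgumentTypeError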
@@ -108,6 +123,36 @@ def rmsnorm_vllm(
     return output


+def rmsnorm_sglang(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    eps: float = 1e-6,
+    enable_pdl: Optional[bool] = None,
+):
+    orig_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+    if residual is not None:
+        residual = residual.view(-1, residual.shape[-1])
+    if enable_pdl is None:
+        enable_pdl = is_arch_support_pdl()
+
+    if residual is not None:
+        sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps, enable_pdl=enable_pdl)
+        output = (x, residual)
+    else:
+        out = torch.empty_like(x)
+        sgl_kernel.rmsnorm(x, weight, eps, out=out, enable_pdl=enable_pdl)
+        output = out
+
+    if isinstance(output, tuple):
+        output = (output[0].view(orig_shape), output[1].view(orig_shape))
+    else:
+        output = output.view(orig_shape)
+    return output
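A quick illustration of the wrapper's two return shapes (assuming a CUDA device with sgl_kernel installed; this snippet is not part of the diff):

x = torch.randn(2, 8, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

y = rmsnorm_sglang(x, w)  # single tensor, shape (2, 8, 4096)
# fused-add path returns a tuple; note it updates its inputs in place
y2, res = rmsnorm_sglang(x, w, torch.randn_like(x))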
 def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
     dtype = torch.bfloat16
     x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")

@@ -123,108 +168,151 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
     output_vllm = rmsnorm_vllm(
         x.clone(), weight, residual.clone() if residual is not None else None
     )
+    output_sglang = rmsnorm_sglang(
+        x.clone(), weight, residual.clone() if residual is not None else None
+    )

     if use_residual:
         output_naive = output_naive[0]
         output_flashinfer = output_flashinfer[0]
         output_vllm = output_vllm[0]
+        output_sglang = output_sglang[0]

     print(f"Naive output={output_naive}")
     print(f"FlashInfer output={output_flashinfer}")
     print(f"VLLM output={output_vllm}")
+    print(f"SGLang output={output_sglang}")

-    if torch.allclose(
-        output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
-    ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
+    if (
+        torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
+        and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
+        and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
+    ):
         print("✅ All implementations match")
     else:
         print("❌ Implementations differ")
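What every provider is expected to agree on is plain RMSNorm, y = x · w / sqrt(mean(x²) + eps), with the fused variants first adding the residual into x. A hedged torch-only reference (the function name is illustrative; in the diff, rmsnorm_naive and HuggingFaceRMSNorm play this role):

import torch

def rms_norm_ref(x, weight, residual=None, eps=1e-6):
    if residual is not None:
        x = x + residual          # fused-add variants add the residual first...
        residual = x              # ...and also return the updated residual
    x32 = x.float()               # accumulate the variance in float32
    var = x32.pow(2).mean(dim=-1, keepdim=True)
    y = (x32 * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return (y, residual) if residual is not None else y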
-batch_size_range = [2**i for i in range(0, 7, 2)]
-seq_length_range = [2**i for i in range(6, 11, 1)]
-head_num_range = [32, 48]
-configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))


-def get_benchmark(use_residual):
-    @triton.testing.perf_report(
-        triton.testing.Benchmark(
-            x_names=["head_num", "batch_size", "seq_len"],
-            x_vals=[list(_) for _ in configs],
-            line_arg="provider",
-            line_vals=["huggingface", "flashinfer", "vllm"],
-            line_names=["HuggingFace", "FlashInfer", "vLLM"],
-            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
-            ylabel="us",
-            plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
-            args={},
-        )
-    )
+default_batch_sizes = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
+default_seq_lens = [2**i for i in range(6, 11, 1)]  # 64, 128, 256, 512, 1024
+default_hidden_sizes = [32 * 128, 48 * 128]  # 4096, 6144


+def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
+    return list(itertools.product(bsizes, slens, hsizes))


+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size", "seq_len", "hidden_size"],
+        x_vals=[],  # filled in at runtime from the CLI arguments (see __main__)
+        line_arg="provider",
+        line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
+        line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
+        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
+        ylabel="µs (median) or × (speed-up)",
+        plot_name="rmsnorm-performance",
+        args={},
+    )
+)
-def benchmark(head_num, batch_size, seq_len, provider):
-    dtype = torch.bfloat16
-    hidden_size = head_num * 128  # assuming head_dim = 128
-    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
-    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
-    residual = torch.randn_like(x) if use_residual else None
-    quantiles = [0.5, 0.2, 0.8]
-    if provider == "huggingface":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: rmsnorm_naive(
-                x.clone(),
-                weight,
-                residual.clone() if residual is not None else None,
-            ),
-            quantiles=quantiles,
-        )
+def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
+    device = torch.device("cuda")
+    dtype = torch.bfloat16
+    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
+    weight = torch.ones(hidden_size, dtype=dtype, device=device)
+    residual = torch.randn_like(x) if use_residual else None

+    # timing helper
+    def timed(fn):
+        for _ in range(5):  # warm up outside the graph capture
+            fn()
+        torch.cuda.synchronize()
+        ms, qmin, qmax = triton.testing.do_bench_cudagraph(fn, quantiles=[0.5, 0.2, 0.8])
+        return 1000 * ms, 1000 * qmax, 1000 * qmin
+    if provider == "huggingface":
+        return timed(
+            lambda: rmsnorm_naive(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
+        )
-    elif provider == "flashinfer":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: rmsnorm_flashinfer(
-                x.clone(),
-                weight,
-                residual.clone() if residual is not None else None,
-            ),
-            quantiles=quantiles,
-        )
+    elif provider == "flashinfer":
+        return timed(
+            lambda: rmsnorm_flashinfer(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
+        )
-    else:
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: rmsnorm_vllm(
-                x.clone(),
-                weight,
-                residual.clone() if residual is not None else None,
-            ),
-            quantiles=quantiles,
-        )
+    elif provider == "vllm":
+        return timed(
+            lambda: rmsnorm_vllm(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
+        )
+    elif provider == "sglang":
+        return timed(
+            lambda: rmsnorm_sglang(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
+        )

-    return 1000 * ms, 1000 * max_ms, 1000 * min_ms

-    return benchmark
+    # provider == "speedup": ratio of the vLLM median time to the SGLang
+    # median time; the same value is returned for all three quantiles so the
+    # perf report draws a flat speed-up line.
+    t_ref, _, _ = timed(
+        lambda: rmsnorm_vllm(
+            x.clone(),
+            weight,
+            residual.clone() if residual is not None else None,
+        )
+    )
+    t_sgl, _, _ = timed(
+        lambda: rmsnorm_sglang(
+            x.clone(),
+            weight,
+            residual.clone() if residual is not None else None,
+        )
+    )
+    spd = t_ref / t_sgl
+    return (spd, spd, spd)
 if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--use_residual", action="store_true", help="Whether to use residual connection"
-    )
-    parser.add_argument(
-        "--save_path",
-        type=str,
-        default="./configs/benchmark_ops/rmsnorm/",
-        help="Path to save rmsnorm benchmark results",
-    )
-    args = parser.parse_args()
-
-    # Run correctness test
-    calculate_diff(batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual)
-
-    # Get the benchmark function with proper use_residual setting
-    benchmark = get_benchmark(args.use_residual)
-    # Run performance benchmark
-    benchmark.run(print_data=True, save_path=args.save_path)
+    p = argparse.ArgumentParser("RMSNorm kernel benchmark")
+    p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
+    p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
+    p.add_argument("--hidden_sizes", type=str2int_list, default=default_hidden_sizes)
+    p.add_argument(
+        "--use_residual", action="store_true", help="Whether to use residual connection"
+    )
+    p.add_argument("--verify_only", action="store_true")
+    args = p.parse_args()
+
+    # coerce lists (defaults are already lists; CLI values arrive as strings)
+    if isinstance(args.batch_sizes, str):
+        args.batch_sizes = str2int_list(args.batch_sizes)
+    if isinstance(args.seq_lens, str):
+        args.seq_lens = str2int_list(args.seq_lens)
+    if isinstance(args.hidden_sizes, str):
+        args.hidden_sizes = str2int_list(args.hidden_sizes)
+
+    # patch the perf_report grid with the requested configurations
+    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes)
+    if hasattr(benchmark, "benchmarks"):
+        benchmark.benchmarks.x_vals = benchmark_grid
+    else:
+        benchmark.benchmark.x_vals = benchmark_grid
+
+    if args.verify_only:
+        ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual)
+        print("✅ sanity pass" if ok else "❌ mismatch")
+    else:
+        benchmark.run(print_data=True, use_residual=args.use_residual)
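The hasattr dance works because triton.testing.perf_report wraps the function in a Mark object that keeps a reference to its Benchmark config, so an empty x_vals can be filled in after import but before .run(). A self-contained sketch of the same pattern (the demo names are illustrative, and a CUDA device is assumed):

import torch
import triton.testing


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["n"],
        x_vals=[],  # deliberately empty; patched below, as in bench_rmsnorm.py
        line_arg="provider",
        line_vals=["torch"],
        line_names=["Torch"],
        styles=[("blue", "-")],
        ylabel="ms",
        plot_name="demo",
        args={},
    )
)
def demo(n, provider):
    x = torch.randn(n, device="cuda")
    return triton.testing.do_bench(lambda: x * 2)


if hasattr(demo, "benchmarks"):  # current triton stores the config here
    demo.benchmarks.x_vals = [2**i for i in range(10, 13)]
else:  # fallback for older attribute layouts
    demo.benchmark.x_vals = [2**i for i in range(10, 13)]
demo.run(print_data=True)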
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py

@@ -114,7 +114,9 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
                 filter_apply_order="joint",
             )
-        ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+            fn, quantiles=[0.5, 0.2, 0.8]
+        )
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
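Here filter_apply_order="joint" means the top-k and top-p constraints are applied together rather than one after the other. A hedged torch-only sketch of that filtering, for orientation only (it is not the kernel under test):

import torch

def top_k_top_p_mask(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    # top-k: keep the k largest probabilities
    kth = torch.topk(probs, k, dim=-1).values[..., -1:]
    topk_keep = probs >= kth
    # top-p: keep the smallest prefix of the sorted distribution with mass >= p
    sorted_probs, idx = torch.sort(probs, descending=True, dim=-1)
    prefix_mass = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    topp_keep = torch.zeros_like(topk_keep).scatter(-1, idx, prefix_mass < p)
    return topk_keep & topp_keep  # "joint": both constraints at once

mask = top_k_top_p_mask(torch.randn(2, 32000), k=50, p=0.9)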
sgl-kernel/tests/test_norm.py

@@ -3,6 +3,7 @@
 import pytest
 import sgl_kernel
 import torch
+from sgl_kernel.utils import is_arch_support_pdl


 def llama_rms_norm(x, w, eps=1e-6):
@@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
     w = torch.randn(hidden_size).to(0).to(dtype)

     y_ref = llama_rms_norm(x, w)
+    enable_pdl = is_arch_support_pdl()

     if specify_out:
         y = torch.empty_like(x)
-        sgl_kernel.rmsnorm(x, w, out=y)
+        sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
     else:
-        y = sgl_kernel.rmsnorm(x, w)
+        y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)

     torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     x_fused = x.clone()
     residual_fused = residual.clone()
-    sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+    enable_pdl = is_arch_support_pdl()
+    sgl_kernel.fused_add_rmsnorm(
+        x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
+    )

     torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
@@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
     w = torch.randn(hidden_size).to(0).to(dtype)

     y_ref = gemma_rms_norm(x, w)
+    enable_pdl = is_arch_support_pdl()

     if specify_out:
         y = torch.empty_like(x)
-        sgl_kernel.gemma_rmsnorm(x, w, out=y)
+        sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
     else:
-        y = sgl_kernel.gemma_rmsnorm(x, w)
+        y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)

     torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     x_fused = x.clone()
     residual_fused = residual.clone()
-    sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+    enable_pdl = is_arch_support_pdl()
+    sgl_kernel.gemma_fused_add_rmsnorm(
+        x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
+    )

     torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
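All four tests follow the same pattern: query is_arch_support_pdl() once and thread the result through the new enable_pdl keyword. Programmatic Dependent Launch is only available on recent NVIDIA architectures, so the gate keeps the tests portable. A hedged sketch of the pattern outside pytest (assumes sgl_kernel is installed and a CUDA GPU is present):

import torch
import sgl_kernel
from sgl_kernel.utils import is_arch_support_pdl

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

enable_pdl = is_arch_support_pdl()  # False on GPUs without PDL support
y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)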