Commit 11965b0d (Unverified)
Fix sgl-kernel benchmark dead code (#11022)
Authored Sep 29, 2025 by Xiaoyu Zhang; committed by GitHub on Sep 29, 2025
Parent: 71959545
Showing 5 changed files with 242 additions and 58 deletions (+242, -58)
sgl-kernel/benchmark/bench_per_token_quant_fp8.py    +63  -10
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py       +31  -11
sgl-kernel/benchmark/bench_rmsnorm.py                +102 -23
sgl-kernel/benchmark/bench_rotary_embedding.py       +20  -5
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py   +26  -9
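Every file below gains the same gate: IS_CI is true when either the CI or the GITHUB_ACTIONS environment variable equals "true", and the scripts then shrink their parameter sweeps or skip entirely. A minimal sketch for exercising that path locally, assuming you run from the repository root (the subprocess wrapper is illustrative, not part of the commit):

    import os
    import subprocess

    # Force the CI branch: the benchmarks check os.getenv("CI") and
    # os.getenv("GITHUB_ACTIONS") at import time.
    env = dict(os.environ, CI="true")
    subprocess.run(
        ["python", "sgl-kernel/benchmark/bench_rmsnorm.py"],
        env=env,
        check=True,
    )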
sgl-kernel/benchmark/bench_per_token_quant_fp8.py (view file @ 11965b0d)

 import itertools
+import os
 from typing import Optional, Tuple

 import torch
 import triton
 import triton.testing
 from sgl_kernel import sgl_per_token_quant_fp8
-from vllm import _custom_ops as ops
+
+# Optional vLLM import
+try:
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    ops = None
+    VLLM_AVAILABLE = False

 from sglang.srt.utils import is_hip

 _is_hip = is_hip()

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

 # Get correct FP8 E4M3 maximum value
...

@@ -49,6 +65,9 @@ def torch_per_token_quant_fp8(
 def vllm_per_token_quant_fp8(
     input: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    if not VLLM_AVAILABLE:
+        # Fallback to SGLang implementation
+        return sglang_per_token_quant_fp8(input)
     return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
...

@@ -74,6 +93,17 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
     vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

+    if not VLLM_AVAILABLE:
+        print("⚠️ vLLM not available, skipping vLLM comparison")
+        # Only compare Torch vs SGLang
+        torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
+        torch_sglang_out_diff = (
+            torch.abs(torch_out.float() - sglang_out.float()).mean().item()
+        )
+        print(f"Scale difference (Torch vs SGLang): {torch_sglang_scale_diff:.8f}")
+        print(f"Output difference (Torch vs SGLang): {torch_sglang_out_diff:.8f}")
+        return
+
     print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")

     # Compare scales
...

@@ -125,9 +155,15 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
     print(f"  VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}")

-batch_size_range = [16, 32, 64, 128]
-seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
-hidden_dim_range = [1368, 2048, 4096]
+# CI environment uses simplified parameters
+if IS_CI:
+    batch_size_range = [16]  # Single batch size for CI
+    seq_len_range = [64]  # Single sequence length for CI
+    hidden_dim_range = [2048]  # Single hidden dimension for CI
+else:
+    batch_size_range = [16, 32, 64, 128]
+    seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
+    hidden_dim_range = [1368, 2048, 4096]

 configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
...

@@ -137,9 +173,19 @@ configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_ran
         x_names=["batch_size", "seq_len", "hidden_dim"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["torch", "vllm", "sglang"],
-        line_names=["Torch Reference", "VLLM", "SGL Kernel"],
-        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
+        line_vals=(
+            ["torch", "vllm", "sglang"] if VLLM_AVAILABLE else ["torch", "sglang"]
+        ),
+        line_names=(
+            ["Torch Reference", "VLLM", "SGL Kernel"]
+            if VLLM_AVAILABLE
+            else ["Torch Reference", "SGL Kernel"]
+        ),
+        styles=(
+            [("red", "-"), ("blue", "-"), ("green", "-")]
+            if VLLM_AVAILABLE
+            else [("red", "-"), ("green", "-")]
+        ),
         ylabel="us",
         plot_name="per-token-dynamic-quant-fp8-performance",
         args={},
...

@@ -156,6 +202,8 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
     if provider == "torch":
         fn = lambda: torch_per_token_quant_fp8(x.clone())
     elif provider == "vllm":
+        if not VLLM_AVAILABLE:
+            return (0, 0, 0)
         fn = lambda: vllm_per_token_quant_fp8(x.clone())
     elif provider == "sglang":
         fn = lambda: sglang_per_token_quant_fp8(x.clone())
...

@@ -166,11 +214,16 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
 if __name__ == "__main__":
-    # Test various hidden dimensions for correctness
-    test_dims = [1368, 2048, 4096]
+    # Test various hidden dimensions for correctness - simplified for CI
+    if IS_CI:
+        test_dims = [2048]  # Single dimension for CI
+        batch_size, seq_len = 4, 64  # Smaller values for CI
+    else:
+        test_dims = [1368, 2048, 4096]
+        batch_size, seq_len = 4, 4096

     for dim in test_dims:
-        calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)
+        calculate_diff(batch_size=batch_size, seq_len=seq_len, hidden_dim=dim)

     print("\n" + "=" * 60)
     print("Starting performance benchmark...")
...
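For context on what the providers above compute: per-token dynamic FP8 quantization derives one scale per row from that row's absolute maximum. A minimal reference sketch under that assumption (the actual torch_per_token_quant_fp8 body is collapsed in this diff, so the function name and the 1e-12 clamp here are illustrative):

    import torch

    def per_token_quant_fp8_sketch(x: torch.Tensor):
        # One scale per token (row), derived from the row-wise absolute
        # maximum against the FP8 E4M3 representable maximum (448.0).
        fp8_dtype = torch.float8_e4m3fn
        fp8_max = torch.finfo(fp8_dtype).max
        amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
        scale = amax / fp8_max
        q = (x / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
        return q, scale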
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py (view file @ 11965b0d)

 import argparse
 import copy
 import itertools
+import os

 import torch
 import triton
...

@@ -10,6 +11,12 @@ from sgl_kernel import (
     qserve_w4a8_per_group_gemm,
 )

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+

 def to_int8(tensor: torch.Tensor) -> torch.Tensor:
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
...

@@ -65,10 +72,17 @@ WEIGHT_SHAPES = {
 }

+# CI environment uses simplified parameters
+if IS_CI:
+    batch_sizes = [1, 16]  # Simplified for CI
+else:
+    batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
+

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
+        x_vals=batch_sizes,
         x_log=False,
         line_arg="provider",
         line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
...

@@ -184,13 +198,19 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    KN_model_names = prepare_shapes(args)
-    for K, N, model_name in KN_model_names:
-        print(f"{model_name} N={N} K={K}: ")
-        benchmark.run(
-            print_data=True,
-            N=N,
-            K=K,
-        )
-    print("Benchmark finished!")
+    # Skip in CI environment
+    if IS_CI:
+        print("Skipping QServe W4A8 GEMM benchmark in CI environment")
+        print("QServe operations may have compatibility issues in CI")
+    else:
+        KN_model_names = prepare_shapes(args)
+        for K, N, model_name in KN_model_names:
+            print(f"{model_name} N={N} K={K}: ")
+            benchmark.run(
+                print_data=True,
+                N=N,
+                K=K,
+            )
+        print("Benchmark finished!")
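The to_int8 helper shown in the hunk above saturates before rounding, so out-of-range inputs clamp instead of wrapping. A quick check of that behavior (torch.round uses round-half-to-even):

    import torch

    x = torch.tensor([-200.0, -1.4, 0.5, 300.0])
    y = torch.round(x.clamp(min=-128, max=127)).to(dtype=torch.int8)
    print(y)  # tensor([-128,   -1,    0,  127], dtype=torch.int8)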
sgl-kernel/benchmark/bench_rmsnorm.py (view file @ 11965b0d)

...
@@ -2,6 +2,7 @@
 # (batch_size, seq_len, hidden_size) and prints speed-up.
 import argparse
 import itertools
+import os
 import re
 from typing import List, Optional, Tuple, Union
...

@@ -10,9 +11,31 @@ import torch
 import torch.nn as nn
 import triton
 import triton.testing
-from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from sgl_kernel.utils import is_arch_support_pdl
-from vllm import _custom_ops as vllm_ops
+
+# Optional imports
+try:
+    from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+
+    FLASHINFER_AVAILABLE = True
+except ImportError:
+    fused_add_rmsnorm = None
+    rmsnorm = None
+    FLASHINFER_AVAILABLE = False
+
+try:
+    from vllm import _custom_ops as vllm_ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_ops = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)


 def str2int_list(arg: str) -> List[int]:
...

@@ -79,6 +102,10 @@ def rmsnorm_flashinfer(
     residual: Optional[torch.Tensor] = None,
     eps: float = 1e-6,
 ):
+    if not FLASHINFER_AVAILABLE:
+        # Fallback to naive implementation if FlashInfer is not available
+        return rmsnorm_naive(x, weight, residual, eps)
+
     orig_shape = x.shape
     x = x.view(-1, x.shape[-1])
     if residual is not None:
...

@@ -103,6 +130,10 @@ def rmsnorm_vllm(
     residual: Optional[torch.Tensor] = None,
     eps: float = 1e-6,
 ):
+    if not VLLM_AVAILABLE:
+        # Fallback to naive implementation if vLLM is not available
+        return rmsnorm_naive(x, weight, residual, eps)
+
     orig_shape = x.shape
     x = x.view(-1, x.shape[-1])
     if residual is not None:
...

@@ -179,37 +210,72 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
         output_sglang = output_sglang[0]

     print(f"Naive output={output_naive}")
-    print(f"FlashInfer output={output_flashinfer}")
-    print(f"VLLM output={output_vllm}")
+    if FLASHINFER_AVAILABLE:
+        print(f"FlashInfer output={output_flashinfer}")
+    else:
+        print("FlashInfer not available, skipped")
+    if VLLM_AVAILABLE:
+        print(f"VLLM output={output_vllm}")
+    else:
+        print("vLLM not available, skipped")
     print(f"SGLang output={output_sglang}")

-    if (
-        torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
-        and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
-        and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
-    ):
-        print("✅ All implementations match")
+    # Only compare available implementations
+    all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
+    if FLASHINFER_AVAILABLE:
+        all_match = all_match and torch.allclose(
+            output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
+        )
+    if VLLM_AVAILABLE:
+        all_match = all_match and torch.allclose(
+            output_naive, output_vllm, atol=1e-2, rtol=1e-2
+        )
+    if all_match:
+        print("✅ All available implementations match")
     else:
         print("❌ Implementations differ")

-default_batch_sizes = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
-default_seq_lens = [2**i for i in range(6, 11, 1)]  # 64, 128, 256, 512, 1024
-default_hidden_sizes = [32 * 128, 48 * 128]  # 4096, 6144
+# CI environment uses simplified parameters
+if IS_CI:
+    default_batch_sizes = [1]  # Single batch size for CI
+    default_seq_lens = [64]  # Single sequence length for CI
+    default_hidden_sizes = [4096]  # Single hidden size for CI
+else:
+    default_batch_sizes = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
+    default_seq_lens = [2**i for i in range(6, 11, 1)]  # 64, 128, 256, 512, 1024
+    default_hidden_sizes = [32 * 128, 48 * 128]  # 4096, 6144


 def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
     return list(itertools.product(bsizes, slens, hsizes))

+# Filter providers based on availability
+available_providers = ["huggingface", "sglang"]
+available_names = ["HuggingFace", "SGL Kernel"]
+available_styles = [("blue", "-"), ("orange", "-")]
+if FLASHINFER_AVAILABLE:
+    available_providers.insert(-1, "flashinfer")
+    available_names.insert(-1, "FlashInfer")
+    available_styles.insert(-1, ("green", "-"))
+if VLLM_AVAILABLE:
+    available_providers.insert(-1, "vllm")
+    available_names.insert(-1, "vLLM")
+    available_styles.insert(-1, ("red", "-"))
+

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size", "seq_len", "hidden_size"],
         x_vals=[],
         line_arg="provider",
-        line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
-        line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
-        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
+        line_vals=available_providers,
+        line_names=available_names,
+        styles=available_styles,
         ylabel="µs (median) or × (speed-up)",
         plot_name="rmsnorm-performance",
         args={},
...

@@ -242,6 +308,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
             )
         )
     elif provider == "flashinfer":
+        if not FLASHINFER_AVAILABLE:
+            return (0, 0, 0)
         return timed(
             lambda: rmsnorm_flashinfer(
                 x.clone(),
...

@@ -250,6 +318,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
             )
         )
     elif provider == "vllm":
+        if not VLLM_AVAILABLE:
+            return (0, 0, 0)
         return timed(
             lambda: rmsnorm_vllm(
                 x.clone(),
...

@@ -267,13 +337,22 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
     )

     # provider == "speedup"
-    t_ref, _, _ = timed(
-        lambda: rmsnorm_vllm(
-            x.clone(),
-            weight,
-            residual.clone() if residual is not None else None,
-        )
-    )
+    if VLLM_AVAILABLE:
+        t_ref, _, _ = timed(
+            lambda: rmsnorm_vllm(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
+        )
+    else:
+        t_ref, _, _ = timed(
+            lambda: rmsnorm_naive(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
+        )
     t_sgl, _, _ = timed(
         lambda: rmsnorm_sglang(
             x.clone(),
...

@@ -281,7 +360,7 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
             residual.clone() if residual is not None else None,
         )
     )
-    spd = t_ref / t_sgl
+    spd = t_ref / t_sgl if t_ref > 0 else 1.0
     return (spd, spd, spd)
...
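The fallbacks above route to rmsnorm_naive, whose body is collapsed in this diff. A minimal sketch of the standard computation it presumably performs (the residual-add placement and eps handling are assumptions):

    from typing import Optional

    import torch

    def rmsnorm_naive_sketch(
        x: torch.Tensor,
        weight: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        eps: float = 1e-6,
    ) -> torch.Tensor:
        # Optional residual add, mirroring the fused_add_rmsnorm variants.
        if residual is not None:
            x = x + residual
        # RMSNorm: scale each token by the reciprocal RMS of its last dimension.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight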
sgl-kernel/benchmark/bench_rotary_embedding.py (view file @ 11965b0d)

 import itertools
+import os

 import torch
 import triton
...

@@ -12,17 +13,31 @@ from sgl_kernel.testing.rotary_embedding import (
 from sglang.srt.bench_utils import bench_kineto

-configs = [
-    (batch_size, seq_len, save_kv_cache)
-    for batch_size, seq_len in (
-        (1, 1),
-        (32, 1),
-        (128, 1),
-        (512, 1),
-        (2, 512),
-        (4, 4096),
-    )
-    for save_kv_cache in (False, True)
-]
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
+# CI environment uses simplified parameters
+if IS_CI:
+    batch_seq_configs = [(1, 1)]  # Single config for CI
+    save_kv_configs = [False]  # Single option for CI
+else:
+    batch_seq_configs = [
+        (1, 1),
+        (32, 1),
+        (128, 1),
+        (512, 1),
+        (2, 512),
+        (4, 4096),
+    ]
+    save_kv_configs = [False, True]
+
+configs = [
+    (batch_size, seq_len, save_kv_cache)
+    for batch_size, seq_len in batch_seq_configs
+    for save_kv_cache in save_kv_configs
+]
...
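In the non-CI branch the rewritten comprehension is equivalent to the old literal; a quick sanity check:

    batch_seq_configs = [(1, 1), (32, 1), (128, 1), (512, 1), (2, 512), (4, 4096)]
    save_kv_configs = [False, True]

    configs = [
        (batch_size, seq_len, save_kv_cache)
        for batch_size, seq_len in batch_seq_configs
        for save_kv_cache in save_kv_configs
    ]
    assert len(configs) == 12  # 6 (batch, seq) pairs x 2 save_kv_cache options
    assert configs[0] == (1, 1, False) and configs[-1] == (4, 4096, True)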
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py (view file @ 11965b0d)

 import itertools
+import os

 import sgl_kernel
 import torch
 import triton
 import triton.testing

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+

 def torch_top_k_top_p_joint_sampling_from_probs(
     normalized_prob, top_k, top_p, eps=1e-4
...

@@ -67,10 +74,16 @@ def calculate_diff(batch_size, vocab_size, p):
     )

-# parameter space
-batch_size_range = [16, 64, 128]
-vocab_size_range = [111, 32000]
-p_range = [0.1, 0.5]
+# parameter space - simplified for CI
+if IS_CI:
+    batch_size_range = [16]  # Single batch size for CI
+    vocab_size_range = [111]  # Single vocab size for CI
+    p_range = [0.1]  # Single p value for CI
+else:
+    batch_size_range = [16, 64, 128]
+    vocab_size_range = [111, 32000]
+    p_range = [0.1, 0.5]

 configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
...

@@ -114,15 +127,19 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
             filter_apply_order="joint",
         )

-        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
-            fn, quantiles=[0.5, 0.2, 0.8]
-        )
+        ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])

     return 1000 * ms, 1000 * max_ms, 1000 * min_ms


 if __name__ == "__main__":
-    # Correctness check
-    for cfg in configs:
+    # Correctness check - simplified for CI
+    if IS_CI:
+        # Only test one configuration in CI
+        test_configs = [configs[0]] if configs else [(16, 111, 0.1)]
+    else:
+        test_configs = configs
+
+    for cfg in test_configs:
         calculate_diff(*cfg)

     print("\n" + "=" * 60)
...
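For reference, joint top-k/top-p filtering keeps at most k tokens and the smallest prefix of probability mass reaching p, then renormalizes. The benchmark's torch_top_k_top_p_joint_sampling_from_probs body is collapsed here, so the following is an illustrative sketch, not the file's code:

    import torch

    def topk_topp_filter_sketch(
        probs: torch.Tensor, top_k: int, top_p: float
    ) -> torch.Tensor:
        # Sort each row's probabilities in descending order.
        sorted_probs, idx = probs.sort(dim=-1, descending=True)
        # Top-k: keep only the k largest entries per row.
        k_mask = torch.arange(probs.shape[-1], device=probs.device) < top_k
        # Top-p (nucleus): keep entries whose preceding cumulative mass is below top_p.
        cumulative = sorted_probs.cumsum(dim=-1)
        p_mask = (cumulative - sorted_probs) < top_p
        keep = k_mask & p_mask
        filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
        # Renormalize the surviving mass and undo the sort.
        filtered = filtered / filtered.sum(dim=-1, keepdim=True)
        return torch.zeros_like(probs).scatter_(-1, idx, filtered)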