change / sglang / Commits / 11965b0d

Unverified commit 11965b0d, authored Sep 29, 2025 by Xiaoyu Zhang, committed via GitHub on Sep 29, 2025.

Fix sgl-kernel benchmark dead code (#11022)

Parent: 71959545

This page shows 20 of the 25 changed files, with 779 additions and 204 deletions (+779 / -204) across the files shown.
Changed files (additions / deletions):

.github/workflows/pr-test.yml (+45 / -1)
python/sglang/srt/utils.py (+1 / -1)
sgl-kernel/benchmark/bench_activation.py (+55 / -14)
sgl-kernel/benchmark/bench_awq_dequant.py (+42 / -7)
sgl-kernel/benchmark/bench_cutlass_mla.py (+47 / -12)
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py (+23 / -4)
sgl-kernel/benchmark/bench_dsv3_router_gemm.py (+49 / -18)
sgl-kernel/benchmark/bench_fp4_gemm.py (+64 / -31)
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py (+65 / -7)
sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py (+45 / -30)
sgl-kernel/benchmark/bench_fp8_gemm.py (+68 / -18)
sgl-kernel/benchmark/bench_int8_gemm.py (+52 / -11)
sgl-kernel/benchmark/bench_lightning_attention_decode.py (+16 / -3)
sgl-kernel/benchmark/bench_moe_align_block_size.py (+39 / -15)
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py (+14 / -1)
sgl-kernel/benchmark/bench_moe_fused_gate.py (+13 / -1)
sgl-kernel/benchmark/bench_moe_topk_softmax.py (+63 / -16)
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py (+27 / -5)
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py (+31 / -3)
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py (+20 / -6)
.github/workflows/pr-test.yml

@@ -155,6 +155,50 @@ jobs:
       cd test/srt
       python3 test_mla_deepseek_v3.py

+  sgl-kernel-benchmark-test:
+    needs: [check-changes, sgl-kernel-build-wheels]
+    if: always() && !failure() && !cancelled()
+    runs-on: 1-gpu-runner
+    env:
+      HF_TOKEN: ${{ secrets.HF_TOKEN }}
+      CI: true
+    steps:
+      - uses: actions/checkout@v4
+      - name: Cleanup
+        run: |
+          ls -alh sgl-kernel/dist || true
+          rm -rf sgl-kernel/dist/* || true
+      - name: Download artifacts
+        uses: actions/download-artifact@v4
+        with:
+          path: sgl-kernel/dist/
+          merge-multiple: true
+          pattern: wheel-python3.10-cuda12.9
+      - name: Install dependencies
+        run: |
+          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/ci_install_dependency.sh
+      - name: Run benchmark tests
+        timeout-minutes: 45
+        run: |
+          cd sgl-kernel/benchmark
+          echo "Running sgl-kernel benchmark tests in CI mode..."
+          echo "CI environment variable: $CI"
+          echo "GITHUB_ACTIONS environment variable: $GITHUB_ACTIONS"
+          for bench_file in bench_*.py; do
+            echo "Testing $bench_file..."
+            timeout 60 python3 "$bench_file" || echo "Warning: $bench_file timed out or failed, continuing..."
+            echo "Completed $bench_file"
+            echo "---"
+          done
+          echo "All benchmark tests completed!"
+
   # =============================================== primary ====================================================
   unit-test-frontend:

@@ -647,7 +691,7 @@ jobs:
       check-changes,
       sgl-kernel-build-wheels,
       sgl-kernel-unit-test,
       sgl-kernel-mla-test,
+      sgl-kernel-benchmark-test,
       unit-test-frontend,
       unit-test-backend-1-gpu,
       unit-test-backend-2-gpu,
       unit-test-backend-4-gpu,
       unit-test-backend-8-gpu,
       ...
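The new job runs every bench_*.py script with CI=true exported (GitHub Actions additionally sets GITHUB_ACTIONS=true on every runner). Below is a minimal, self-contained sketch of the detection idiom the benchmark files in this commit use to pick those variables up; it mirrors the IS_CI snippet repeated throughout the diff and can be run standalone to check which mode a given shell would select. It is an illustration, not part of the commit.

import os

# Detect a CI run from the environment; either variable being "true" switches
# a benchmark into its reduced CI configuration.
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

if __name__ == "__main__":
    print(f"CI mode: {IS_CI}")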
python/sglang/srt/utils.py

@@ -2460,7 +2460,7 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

-    if get_tensor_model_parallel_rank() == 0:
+    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
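For context, a hedged sketch of why the extra guard matters: in a single-process script that never calls torch.distributed.init_process_group, asking sglang's distributed state for the tensor-parallel rank would typically fail, so the guarded version simply skips the log call instead. The snippet assumes an environment where sglang is importable; it is an illustration, not part of the commit.

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

from sglang.srt.utils import log_info_on_rank0

# No distributed init has happened, so the rank query is skipped entirely and
# the call degrades to a no-op rather than raising.
assert not torch.distributed.is_initialized()
log_info_on_rank0(logger, "only printed on rank 0 of an initialized TP group")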
sgl-kernel/benchmark/bench_activation.py

@@ -2,6 +2,7 @@
 # (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
 import argparse
 import itertools
+import os
 import re
 from typing import List, Tuple

@@ -11,7 +12,21 @@ import torch.nn.functional as F
 import triton
 import triton.testing
 from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-from vllm import _custom_ops as vllm_ops
+
+# Optional vLLM import
+try:
+    from vllm import _custom_ops as vllm_ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_ops = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 # gelu_quick is only available on HIP/ROCm platforms
 try:

@@ -22,7 +37,7 @@ except ImportError:
     GELU_QUICK_AVAILABLE = False
     gelu_quick = None

-if not hasattr(vllm_ops, "silu_and_mul"):
+if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"):
     vllm_ops = torch.ops._C

@@ -40,6 +55,13 @@ def calculate_diff(
     """Compare vLLM with SGLang for one shape."""
     device = torch.device("cuda")

+    if not VLLM_AVAILABLE:
+        print(
+            f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
+            f"L={seq_len:3d} | D={dim:5d}] ⚠️ vLLM not available, skipping comparison"
+        )
+        return True
+
     # activation-only quick GELU
     if kernel == "gelu_quick":
         if not GELU_QUICK_AVAILABLE:

@@ -68,19 +90,30 @@ def calculate_diff(
     return ok

-kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
-if GELU_QUICK_AVAILABLE:
-    kernels.append("gelu_quick")
-dtypes = [torch.float16, torch.bfloat16]
+# CI environment uses simplified parameters for kernels and dtypes too
+if IS_CI:
+    kernels = ["silu_and_mul"]  # Only test one kernel in CI
+    dtypes = [torch.float16]  # Only test one dtype in CI
+else:
+    kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
+    if GELU_QUICK_AVAILABLE:
+        kernels.append("gelu_quick")
+    dtypes = [torch.float16, torch.bfloat16]

 def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
     return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))

-default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1,4,16
-default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1,4,16,64
-default_dims = [2**i for i in range(10, 15)]  # 1024...16384
+# CI environment uses simplified parameters
+if IS_CI:
+    default_batch_sizes = [1]  # Single batch size for CI
+    default_seq_lens = [1]  # Single sequence length for CI
+    default_dims = [1024]  # Single dimension for CI
+else:
+    default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1,4,16
+    default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1,4,16,64
+    default_dims = [2**i for i in range(10, 15)]  # 1024...16384

 @triton.testing.perf_report(

@@ -102,16 +135,24 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
     y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)

-    vllm_kernel = getattr(vllm_ops, kernel)
+    if not VLLM_AVAILABLE and provider in ["vllm", "speedup"]:
+        # Skip vLLM-related benchmarks if vLLM is not available
+        return (0, 0, 0)
+
+    if VLLM_AVAILABLE:
+        vllm_kernel = getattr(vllm_ops, kernel)

     if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
         # Skip benchmark for gelu_quick if not available
         return (0, 0, 0)

     sglang_kernel = getattr(sgl_kernel, kernel)

     def baseline():
-        tmp = y0.clone()
-        vllm_kernel(tmp, x)
-        return tmp
+        if VLLM_AVAILABLE:
+            tmp = y0.clone()
+            vllm_kernel(tmp, x)
+            return tmp
+        else:
+            return torch.zeros_like(y0)

     def sglang():
         return sglang_kernel(x)

@@ -134,7 +175,7 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     # provider == "speedup"
     t_ref, _, _ = timed(baseline)
     t_sgl, _, _ = timed(sglang)
-    spd = t_ref / t_sgl
+    spd = t_ref / t_sgl if t_ref > 0 else 1.0
     return (spd, spd, spd)
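The pattern added to bench_activation.py (and echoed in most of the files below) has two halves: an import that may fail, recorded in a module-level flag, and benchmark branches that return a zero triple when their provider cannot run, so triton.testing still receives a (median, min, max) shaped result. A minimal sketch under the assumption that vLLM may or may not be installed; the timings are placeholders, not measurements.

import os

try:
    from vllm import _custom_ops as vllm_ops  # optional baseline dependency
    VLLM_AVAILABLE = True
except ImportError:
    vllm_ops = None
    VLLM_AVAILABLE = False

def run_provider(provider: str):
    # Providers that depend on the missing library report zeros instead of crashing.
    if provider in ("vllm", "speedup") and not VLLM_AVAILABLE:
        return (0, 0, 0)
    # ... time the requested kernel here ...
    return (1.0, 0.9, 1.1)  # placeholder timings for illustration

if __name__ == "__main__":
    print(run_provider("vllm"))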
sgl-kernel/benchmark/bench_awq_dequant.py

 import itertools
+import os
 from typing import List, Tuple

 import torch
 import triton
 import triton.testing
 from sgl_kernel import awq_dequantize
-from vllm import _custom_ops as ops
+
+# Optional vLLM import
+try:
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    ops = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 def vllm_awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    if not VLLM_AVAILABLE:
+        # Fallback to SGLang implementation
+        return sglang_awq_dequantize(qweight, scales, qzeros)
     return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)

@@ -43,6 +61,10 @@ def calculate_diff(qweight_row: int, qweight_col: int):
         device=device,
     )

+    if not VLLM_AVAILABLE:
+        print("⚠️ vLLM not available, skipping comparison")
+        return
+
     vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
     sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

@@ -56,8 +78,13 @@ def calculate_diff(qweight_row: int, qweight_col: int):
         print("❌ Implementations differ")

-qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
-qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
+# CI environment uses simplified parameters
+if IS_CI:
+    qweight_row_range = [128]  # Single row size for CI
+    qweight_cols_range = [16]  # Single column size for CI
+else:
+    qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
+    qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]

 configs = list(itertools.product(qweight_row_range, qweight_cols_range))

@@ -67,9 +94,9 @@ configs = list(itertools.product(qweight_row_range, qweight_cols_range))
         x_names=["qweight_row", "qweight_col"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["vllm", "sglang"],
-        line_names=["VLLM", "SGL Kernel"],
-        styles=[("blue", "-"), ("green", "-")],
+        line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
+        line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
+        styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
         ylabel="us",
         plot_name="awq-dequantize-performance",
         args={},

@@ -100,6 +127,8 @@ def benchmark(qweight_row, qweight_col, provider):
     quantiles = [0.5, 0.2, 0.8]

     if provider == "vllm":
+        if not VLLM_AVAILABLE:
+            return (0, 0, 0)
         fn = lambda: vllm_awq_dequantize(qweight.clone(), scales.clone(), qzeros.clone())

@@ -114,5 +143,11 @@ def benchmark(qweight_row, qweight_col, provider):
 if __name__ == "__main__":
-    calculate_diff(qweight_row=3584, qweight_col=448)
+    # Simplify for CI environment
+    if IS_CI:
+        qweight_row, qweight_col = 128, 16  # Smaller values for CI
+    else:
+        qweight_row, qweight_col = 3584, 448
+
+    calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col)
     benchmark.run(print_data=True)
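bench_awq_dequant.py also trims the plot configuration itself when the baseline is missing, so the decorated benchmark is never invoked with an unavailable provider. A small sketch of that idea with triton.testing.Benchmark; the shapes and names here are placeholders, not the ones used by the real script.

import triton
import triton.testing

VLLM_AVAILABLE = False  # pretend the optional baseline is absent

bench_cfg = triton.testing.Benchmark(
    x_names=["n"],
    x_vals=[128, 256],
    line_arg="provider",
    # Only advertise providers that can actually run.
    line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
    line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
    styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
    ylabel="us",
    plot_name="example-performance",
    args={},
)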
sgl-kernel/benchmark/bench_cutlass_mla.py

 import argparse
 import copy
 import itertools
+import os

 import torch
 import triton
 from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

-bs_range = [1, 8, 32, 64, 128, 256]
-qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
+from sglang.srt.utils import get_device_capability
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
+# CI environment uses simplified parameters
+if IS_CI:
+    bs_range = [1]  # Single batch size for CI
+    qlen_range = [64]  # Single sequence length for CI
+else:
+    bs_range = [1, 8, 32, 64, 128, 256]
+    qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

 configs = list(itertools.product(bs_range, qlen_range))

@@ -131,13 +145,34 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    for block_size in args.block_sizes:
-        for kv_split in args.num_kv_splits:
-            print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
-            benchmark.run(
-                print_data=True,
-                block_size=block_size,
-                num_kv_splits=kv_split,
-            )
-    print("Benchmark finished!")
+    # Skip in CI environment or unsupported architectures
+    if IS_CI:
+        major, minor = get_device_capability()
+        if major is None or major < 10:
+            # Requires compute capability 10.0+
+            print("Skipping Cutlass MLA benchmark in CI environment")
+            if major is not None:
+                print(f"Cutlass MLA requires compute capability 10.0+, but found {major}.{minor}")
+            else:
+                print("Could not determine device capability")
+        else:
+            for block_size in args.block_sizes:
+                for kv_split in args.num_kv_splits:
+                    print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
+                    benchmark.run(
+                        print_data=True,
+                        block_size=block_size,
+                        num_kv_splits=kv_split,
+                    )
+            print("Benchmark finished!")
+    else:
+        for block_size in args.block_sizes:
+            for kv_split in args.num_kv_splits:
+                print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
+                benchmark.run(
+                    print_data=True,
+                    block_size=block_size,
+                    num_kv_splits=kv_split,
+                )
+        print("Benchmark finished!")
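The CI branch added here gates on compute capability before touching the kernel. A hedged sketch of the same gate using torch.cuda.get_device_capability as a stand-in for sglang's get_device_capability (which the script imports); the threshold of 10 matches the "compute capability 10.0+" requirement stated in the diff.

import torch

def capability_or_none():
    # Mirror the (major, minor) / (None, None) contract used above.
    if not torch.cuda.is_available():
        return None, None
    return torch.cuda.get_device_capability()

major, minor = capability_or_none()
if major is None or major < 10:
    print("Skipping benchmark: requires compute capability 10.0+")
else:
    print(f"Running on sm{major}{minor}")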
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py

 import argparse
+import os
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 import torch
 import torch.nn.functional as F

@@ -6,16 +13,28 @@ import triton
 import triton.testing
 from sgl_kernel import dsv3_fused_a_gemm

+# CI environment uses simplified parameters
+if IS_CI:
+    num_tokens_vals = [1]  # Only test 1 value in CI
+    line_vals = ["sgl-kernel"]  # Only test sgl-kernel implementation in CI
+else:
+    num_tokens_vals = [i + 1 for i in range(16)]  # Test 1-16 in full mode
+    line_vals = ["torch", "sgl-kernel"]

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens"],
-        x_vals=[i + 1 for i in range(16)],
+        x_vals=num_tokens_vals,
         x_log=False,
         line_arg="impl",
-        line_vals=["torch", "sgl-kernel"],
-        line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=line_vals,
+        line_names=(
+            ["torch (bf16)", "dsv3_fused_a_gemm"] if not IS_CI else ["dsv3_fused_a_gemm"]
+        ),
+        styles=[("blue", "-"), ("orange", "-")] if not IS_CI else [("orange", "-")],
         ylabel="TFLOPs",
         plot_name="bf16 dsv3 fused a GEMM throughput",
         args={},
sgl-kernel/benchmark/bench_dsv3_router_gemm.py

 import argparse
+import os
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 import torch
 import torch.nn.functional as F

@@ -6,21 +13,37 @@ import triton
 import triton.testing
 from sgl_kernel import dsv3_router_gemm

+# CI environment uses simplified parameters
+if IS_CI:
+    num_tokens_vals = [1]  # Only test 1 value in CI
+    line_vals = ["sgl-kernel-256"]  # Only test one implementation in CI
+else:
+    num_tokens_vals = [i + 1 for i in range(16)]  # Test 1-16 in full mode
+    line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"]

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens"],
-        x_vals=[i + 1 for i in range(16)],
+        x_vals=num_tokens_vals,
         x_log=False,
         line_arg="impl",
-        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
-        line_names=[
-            "torch-256",
-            "dsv3_router_gemm-256",
-            "torch-384",
-            "dsv3_router_gemm-384",
-        ],
-        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
+        line_vals=line_vals,
+        line_names=(
+            [
+                "torch-256",
+                "dsv3_router_gemm-256",
+                "torch-384",
+                "dsv3_router_gemm-384",
+            ]
+            if not IS_CI
+            else ["dsv3_router_gemm-256"]
+        ),
+        styles=(
+            [("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
+            if not IS_CI
+            else [("orange", "-")]
+        ),
         ylabel="TFLOPs",
         plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
         args={},

@@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl):
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens"],
-        x_vals=[i + 1 for i in range(16)],
+        x_vals=num_tokens_vals,
         x_log=False,
         line_arg="impl",
-        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
-        line_names=[
-            "torch-256",
-            "dsv3_router_gemm-256",
-            "torch-384",
-            "dsv3_router_gemm-384",
-        ],
-        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
+        line_vals=line_vals,
+        line_names=(
+            [
+                "torch-256",
+                "dsv3_router_gemm-256",
+                "torch-384",
+                "dsv3_router_gemm-384",
+            ]
+            if not IS_CI
+            else ["dsv3_router_gemm-256"]
+        ),
+        styles=(
+            [("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
+            if not IS_CI
+            else [("orange", "-")]
+        ),
         ylabel="TFLOPs",
         plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
         args={},
sgl-kernel/benchmark/bench_fp4_gemm.py

@@ -2,6 +2,7 @@ import argparse
 import copy
 import csv
 import itertools
+import os

 import pytest
 import torch

@@ -9,6 +10,14 @@ import triton
 from flashinfer import mm_fp4
 from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

+from sglang.srt.utils import get_device_capability
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
 FLOAT4_E2M1_MAX = 6.0
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

@@ -33,27 +42,34 @@ def get_weight_shapes(args):
     ]

+# CI environment uses simplified parameters
+if IS_CI:
+    batch_sizes = [1, 8]  # Simplified for CI
+else:
+    batch_sizes = [
+        1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 3072, 4096, 8192, 16384,
+    ]

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[
-            1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 3072, 4096, 8192, 16384,
-        ],
+        x_vals=batch_sizes,
         # x_vals = [64],
         x_log=False,
         line_arg="provider",

@@ -188,21 +204,38 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

+    # Simplify for CI environment
+    if IS_CI:
+        args.tp_sizes = [args.tp_sizes[0]]  # Use only first TP size
+
     if args.csv:
         with open(args.csv, "w", newline="") as f:
             writer = csv.writer(f)
             writer.writerow(["provider", "m", "n", "k", "time_ms"])

-    NKs = get_weight_shapes(args)
-    for N, K in NKs:
-        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
-        benchmark.run(
-            print_data=True,
-            N=N,
-            K=K,
-            dtype=args.dtype,
-            correctness=args.correctness,
-            csv_file=args.csv,
-        )
-
-    print("Benchmark finished!")
+    # Check architecture compatibility - FP4 operations require sm100a/sm103a
+    major, minor = get_device_capability()
+    if major is None or major < 10:
+        # Requires compute capability 10.0+ (sm100a/sm103a)
+        print("Skipping FP4 GEMM benchmark")
+        if major is not None:
+            print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
+        else:
+            print("Could not determine device capability")
+    else:
+        NKs = get_weight_shapes(args)
+        # Limit iterations in CI
+        if IS_CI:
+            NKs = NKs[:2]  # Only test first 2 shapes in CI
+        for N, K in NKs:
+            print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
+            benchmark.run(
+                print_data=True,
+                N=N,
+                K=K,
+                dtype=args.dtype,
+                correctness=args.correctness,
+                csv_file=args.csv,
+            )
+
+        print("Benchmark finished!")
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py

 import argparse
 import copy
 import itertools
+import os

 import deep_gemm
 import torch
 import triton
 from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
-from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+
+# Optional vLLM import
+try:
+    from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_scaled_mm = None
+    VLLM_AVAILABLE = False

 from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
 )

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 def get_weight_shapes(args):
     models_tps = list(itertools.product(args.models, args.tp_sizes))

@@ -80,15 +95,46 @@ def scale_shape(shape, group_shape):
     return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))

+# CI environment uses simplified parameters
+if IS_CI:
+    batch_sizes = [1, 8]  # Simplified for CI
+else:
+    batch_sizes = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
+
+# Filter providers based on availability
+available_providers = ["sgl-kernel"]
+available_names = ["sgl-kernel"]
+available_styles = [("orange", "-")]
+
+if VLLM_AVAILABLE:
+    available_providers.insert(0, "vllm")
+    available_names.insert(0, "vllm")
+    available_styles.insert(0, ("blue", "-"))
+
+available_providers.append("triton")
+available_names.append("sglang triton")
+available_styles.append(("red", "-"))
+
+# Add deepgemm if available
+try:
+    import deep_gemm
+
+    available_providers.append("deepgemm")
+    available_names.append("deepgemm")
+    available_styles.append(("yellow", "-"))
+except ImportError:
+    pass

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
+        x_vals=batch_sizes,
         x_log=False,
         line_arg="provider",
-        line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
-        line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
-        styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
+        line_vals=available_providers,
+        line_names=available_names,
+        styles=available_styles,
         ylabel="GB/s",
         plot_name="fp8 blockwise scaled matmul",
         args={},

@@ -123,14 +169,16 @@ def benchmark(batch_size, provider, N, K):
             ),
             quantiles=quantiles,
         )
-    if provider == "vllm":
+    elif provider == "vllm":
+        if not VLLM_AVAILABLE:
+            return (0, 0, 0)
         scale_a = scale_a.t().contiguous().t()
         b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
-    if provider == "triton":
+    elif provider == "triton":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: w8a8_block_fp8_matmul(
                 a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16

@@ -166,7 +214,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

+    # Simplify for CI environment
+    if IS_CI:
+        args.models = [args.models[0]]  # Use only first model
+        args.tp_sizes = [args.tp_sizes[0]]  # Use only first TP size
+
     NK_model_names = get_weight_shapes(args)
+
+    # Limit iterations in CI
+    if IS_CI:
+        NK_model_names = NK_model_names[:2]  # Only test first 2 shapes in CI
+
     for N, K, model_name in NK_model_names:
         if N % 128 != 0 or K % 128 != 0:
             print(f"Skip {N=}, {K=} now")
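bench_fp8_blockwise_gemm.py goes a step further and assembles the provider, label, and style lists dynamically so optional backends only appear when their imports succeed, keeping the three lists index-aligned for triton's Benchmark. A compact, standalone sketch of that registry idea, with deep_gemm standing in for any optional backend; it is an illustration, not part of the commit.

providers = ["sgl-kernel"]
names = ["sgl-kernel"]
styles = [("orange", "-")]

try:
    import deep_gemm  # noqa: F401  (optional backend; may not be installed)

    providers.append("deepgemm")
    names.append("deepgemm")
    styles.append(("yellow", "-"))
except ImportError:
    pass

# The three lists stay the same length, so they can be passed straight to
# triton.testing.Benchmark(line_vals=providers, line_names=names, styles=styles).
print(list(zip(providers, names, styles)))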
sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py

 import argparse
+import os
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 import random
 from dataclasses import dataclass
 from typing import List, Tuple

@@ -290,36 +297,44 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--num-warmup", type=int, default=3)
     parser.add_argument("--num-run", type=int, default=10)

+    # CI environment uses simplified parameters
+    if IS_CI:
+        shape_args = [
+            # Only test one simple shape in CI
+            ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
+        ]
+    else:
+        shape_args = [
            # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
            ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
            # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
            ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
            # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
            ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
            # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
            ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
            # Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
            ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
            # Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
            ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
            # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
            ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
            # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
            ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
            # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
            ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
            # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
            ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
            # Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
            ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
            # Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
            ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
            # Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
            ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
            # Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
            ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
        ]

     args = parser.parse_args()
     benchmark_one_shape(shape_args, args.num_warmup, args.num_run)
sgl-kernel/benchmark/bench_fp8_gemm.py

 import argparse
 import copy
 import itertools
+import os
 from typing import Optional, Tuple

 import torch
 import triton
 from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
 from sgl_kernel import sgl_per_tensor_quant_fp8
-from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
-from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
+
+# Optional vLLM import
+try:
+    from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+    from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_scaled_mm = None
+    vllm_scaled_fp8_quant = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 # Weight Shapes are in the format
 # ([K, N], TP_SPLIT_DIM)

@@ -86,25 +102,48 @@ def sglang_scaled_fp8_quant(
     return output, scale

+# CI environment uses simplified parameters
+if IS_CI:
+    batch_sizes = [1]  # Single batch size for CI
+else:
+    batch_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048]
+
+# Filter line_vals based on vLLM availability
+if VLLM_AVAILABLE:
+    line_vals = [
+        "vllm-fp8-fp16",
+        "vllm-fp8-bf16",
+        "sglang-fp8-fp16",
+        "sglang-fp8-bf16",
+    ]
+    line_names = [
+        "vllm-fp8-fp16",
+        "vllm-fp8-bf16",
+        "sglang-fp8-fp16",
+        "sglang-fp8-bf16",
+    ]
+    styles = [("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")]
+else:
+    line_vals = [
+        "sglang-fp8-fp16",
+        "sglang-fp8-bf16",
+    ]
+    line_names = [
+        "sglang-fp8-fp16",
+        "sglang-fp8-bf16",
+    ]
+    styles = [("blue", "-"), ("blue", "--")]

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
+        x_vals=batch_sizes,
         x_log=False,
         line_arg="provider",
-        line_vals=[
-            "vllm-fp8-fp16",
-            "vllm-fp8-bf16",
-            "sglang-fp8-fp16",
-            "sglang-fp8-bf16",
-        ],
-        line_names=[
-            "vllm-fp8-fp16",
-            "vllm-fp8-bf16",
-            "sglang-fp8-fp16",
-            "sglang-fp8-bf16",
-        ],
-        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
+        line_vals=line_vals,
+        line_names=line_names,
+        styles=styles,
         ylabel="GB/s",
         plot_name="fp8 scaled matmul",
         args={},

@@ -115,6 +154,9 @@ def benchmark(batch_size, provider, N, K):
     M = batch_size
     a = torch.ones((M, K), device="cuda") * 5.0
     b = torch.ones((N, K), device="cuda") * 5.0
+    # vLLM expects scalar scales, while sglang can handle per-token scales
+    scale_a_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
+    scale_b_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
     scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
     scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
     quantiles = [0.5, 0.2, 0.8]

@@ -122,8 +164,11 @@ def benchmark(batch_size, provider, N, K):
     dtype = torch.float16 if "fp16" in provider else torch.bfloat16

     if "vllm-fp8" in provider:
-        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
+        if not VLLM_AVAILABLE:
+            # Return zero if vLLM is not available
+            return (0, 0, 0)
+        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_scalar)
+        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b_scalar)
         b_fp8 = b_fp8.t()
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),

@@ -174,6 +219,11 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

+    # Simplify for CI environment
+    if IS_CI:
+        args.models = [args.models[0]]  # Use only first model
+        args.tp_sizes = [args.tp_sizes[0]]  # Use only first TP size
+
     KN_model_names = prepare_shapes(args)
     for K, N, model_name in KN_model_names:
         print(f"{model_name} N={N} K={K}: ")
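Besides the availability guards, bench_fp8_gemm.py fixes a scale-shape mismatch: the comment in the diff notes that the vLLM quantization path is fed one scalar scale per tensor, while the sglang path keeps a per-token (per-row) scale vector. A shapes-only sketch of that distinction; no kernels are invoked and CPU tensors are used so it runs anywhere. It is an illustration, not part of the commit.

import torch

M, N = 4, 8  # rows of A and B in an M x K / N x K GEMM setup

scale_a_scalar = torch.randn(1, dtype=torch.float32)       # one scale for all of A
scale_b_scalar = torch.randn(1, dtype=torch.float32)       # one scale for all of B
scale_a_per_token = torch.randn(M, dtype=torch.float32)    # one scale per row of A
scale_b_per_channel = torch.randn(N, dtype=torch.float32)  # one scale per row of B

print(scale_a_scalar.shape, scale_a_per_token.shape)
print(scale_b_scalar.shape, scale_b_per_channel.shape)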
sgl-kernel/benchmark/bench_int8_gemm.py

 import argparse
 import copy
 import itertools
+import os

 import torch
 import triton
 from sgl_kernel import int8_scaled_mm
-from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+
+# Optional vLLM import
+try:
+    from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_scaled_mm = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 def to_int8(tensor: torch.Tensor) -> torch.Tensor:

@@ -62,15 +77,32 @@ WEIGHT_SHAPES = {
 }

+# CI environment uses simplified parameters
+if IS_CI:
+    batch_sizes = [1]  # Single batch size for CI
+else:
+    batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
+
+# Filter providers based on vLLM availability
+if VLLM_AVAILABLE:
+    line_vals = ["vllm", "sgl-kernel"]
+    line_names = ["vllm int8 gemm", "sgl-kernel int8 gemm"]
+    styles = [("blue", "-"), ("orange", "-")]
+else:
+    line_vals = ["sgl-kernel"]
+    line_names = ["sgl-kernel int8 gemm"]
+    styles = [("orange", "-")]

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
+        x_vals=batch_sizes,
         x_log=False,
         line_arg="provider",
-        line_vals=["vllm", "sgl-kernel"],
-        line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=line_vals,
+        line_names=line_names,
+        styles=styles,
         ylabel="GB/s",
         plot_name="int8 scaled matmul",
         args={},

@@ -90,7 +122,9 @@ def benchmark(batch_size, provider, N, K):
             lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
             quantiles=quantiles,
         )
-    if provider == "vllm":
+    elif provider == "vllm":
+        if not VLLM_AVAILABLE:
+            return (0, 0, 0)
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
             quantiles=quantiles,

@@ -136,9 +170,16 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    KN_model_names = prepare_shapes(args)
-    for K, N, model_name in KN_model_names:
-        print(f"{model_name} N={N} K={K}: ")
-        benchmark.run(print_data=True, N=N, K=K)
-
-    print("Benchmark finished!")
+    # Skip in CI environment due to architecture compatibility issues
+    if IS_CI:
+        print("Skipping INT8 GEMM benchmark in CI environment due to architecture compatibility issues")
+        print("INT8 operations may not be supported on all GPU architectures")
+    else:
+        KN_model_names = prepare_shapes(args)
+        for K, N, model_name in KN_model_names:
+            print(f"{model_name} N={N} K={K}: ")
+            benchmark.run(print_data=True, N=N, K=K)
+
+        print("Benchmark finished!")
sgl-kernel/benchmark/bench_lightning_attention_decode.py

 import itertools
 import math
+import os

 import torch
 import triton
 import triton.language as tl
 from sgl_kernel import lightning_attention_decode

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 def next_power_of_2(n):
     return 2 ** (int(math.ceil(math.log(n, 2))))

@@ -207,7 +214,12 @@ def calculate_diff(batch_size):
         print("❌ Implementations differ")

-batch_size_range = [i for i in range(1, 65)]  # 1 to 128
+# Simplified for CI environment
+if IS_CI:
+    batch_size_range = [1]  # Single batch size for CI
+else:
+    batch_size_range = [i for i in range(1, 65)]  # 1 to 64
 configs = [(bs,) for bs in batch_size_range]

@@ -292,8 +304,9 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    # Run correctness test
-    calculate_diff(batch_size=4)
+    # Run correctness test - simplified for CI
+    test_batch_size = 1 if IS_CI else 4
+    calculate_diff(batch_size=test_batch_size)

     # Run performance benchmark
     benchmark.run(print_data=True)
sgl-kernel/benchmark/bench_moe_align_block_size.py

 import argparse
 import itertools
+import os

 import torch
 import triton

@@ -8,8 +9,17 @@ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 try:
     from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
 except ImportError:
     ops = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 USE_RANDOM_PERM = False

@@ -197,19 +207,23 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         num_tokens_post_pad_triton,
     )

-    try:
-        ops.moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids_vllm,
-            expert_ids_vllm,
-            num_tokens_post_pad_vllm,
-        )
-        print(f"✅ VLLM implementation works with {num_experts} experts!")
-        vllm_works = True
-    except Exception as e:
-        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
-        vllm_works = False
+    if VLLM_AVAILABLE:
+        try:
+            ops.moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids_vllm,
+                expert_ids_vllm,
+                num_tokens_post_pad_vllm,
+            )
+            print(f"✅ VLLM implementation works with {num_experts} experts!")
+            vllm_works = True
+        except Exception as e:
+            print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
+            vllm_works = False
+    else:
+        print("⚠️ vLLM not available, skipping vLLM test")
+        vllm_works = False

     if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(

@@ -394,8 +408,18 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
+    # Simplify for CI environment
+    if IS_CI:
+        num_tokens = 256  # Smaller for CI
+        num_experts = 8  # Smaller for CI
+        topk = 2  # Smaller for CI
+    else:
+        num_tokens = 1024
+        num_experts = args.num_experts
+        topk = args.topk
+
+    calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk)

-    if not args.skip_full_benchmark:
+    if not args.skip_full_benchmark and not IS_CI:  # Skip full benchmark in CI
         print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
         benchmark.run(print_data=True)
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py

+import os
+
 import torch

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 import triton
 from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel

-batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
+# CI environment uses simplified parameters
+if IS_CI:
+    batch_sizes = [64, 128]  # Only test 2 values in CI
+else:
+    batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
 configs = [(bs,) for bs in batch_sizes]
sgl-kernel/benchmark/bench_moe_fused_gate.py

 import itertools
 import math
+import os

 import torch
 import triton

@@ -8,6 +9,12 @@ from sgl_kernel import moe_fused_gate
 from sglang.srt.layers.moe.topk import biased_grouped_topk

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
     return biased_grouped_topk(

@@ -28,7 +35,12 @@ def biased_grouped_topk_org_fuse_kernel(
     return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)

-seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
+# CI environment uses simplified parameters
+if IS_CI:
+    seq_length_range = [5000]  # Only test one sequence length in CI
+else:
+    seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
 configs = [(sq,) for sq in seq_length_range]
sgl-kernel/benchmark/bench_moe_topk_softmax.py

 import itertools
+import os

 import pytest
 import torch
 import triton
 from sgl_kernel import topk_softmax
-from vllm import _custom_ops as vllm_custom_ops
+
+# Optional vLLM import
+try:
+    from vllm import _custom_ops as vllm_custom_ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_custom_ops = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 def vllm_topk_softmax(gating_output, topk):
+    if not VLLM_AVAILABLE:
+        # Fallback to SGLang implementation if vLLM is not available
+        return sglang_topk_softmax(gating_output, topk)
+
     num_tokens, num_experts = gating_output.shape
     topk_weights = torch.empty(

@@ -54,6 +73,10 @@ def calculate_diff(num_tokens, num_experts, topk):
     weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
     indices_match = torch.equal(indices_vllm, indices_sglang)

+    if not VLLM_AVAILABLE:
+        print("⚠️ vLLM not available, skipping comparison")
+        return
+
     if (
         torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
         and indices_match

@@ -65,21 +88,38 @@ def calculate_diff(num_tokens, num_experts, topk):
     )

-num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
-num_experts_range = [32, 64, 128, 256, 12, 512]
-topk_range = [1, 2, 4, 8]
+# CI environment uses simplified parameters
+if IS_CI:
+    num_tokens_range = [128]  # Single value for CI
+    num_experts_range = [32]  # Single value for CI
+    topk_range = [2]  # Single value for CI
+else:
+    num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
+    num_experts_range = [32, 64, 128, 256, 12, 512]
+    topk_range = [1, 2, 4, 8]

 configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))

+# Filter providers based on vLLM availability
+if VLLM_AVAILABLE:
+    line_vals = ["sglang", "vllm"]
+    line_names = ["SGLang", "VLLM"]
+    styles = [("blue", "-"), ("green", "-")]
+else:
+    line_vals = ["sglang"]
+    line_names = ["SGLang"]
+    styles = [("blue", "-")]

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens", "num_experts", "topk"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["sglang", "vllm"],
-        line_names=["SGLang", "VLLM"],
-        styles=[("blue", "-"), ("green", "-")],
+        line_vals=line_vals,
+        line_names=line_names,
+        styles=styles,
         ylabel="Latency (us)",
         plot_name="topk-softmax-performance",
         args={},

@@ -92,6 +132,8 @@ def benchmark(num_tokens, num_experts, topk, provider):
     )

     if provider == "vllm" or provider == "vllm1":
+        if not VLLM_AVAILABLE:
+            return (0, 0, 0)
         fn = lambda: vllm_topk_softmax(gating_output, topk)
     elif provider == "sglang" or provider == "sglang1":
         fn = lambda: sglang_topk_softmax(gating_output, topk)

@@ -103,14 +145,19 @@ def benchmark(num_tokens, num_experts, topk, provider):
 if __name__ == "__main__":
-    configs = [
-        (20, 256, 4),
-        (20, 256, 8),
-        (20, 12, 4),
-        (20, 12, 1),
-        (20, 512, 4),
-        (20, 512, 1),
-    ]
-    for num_tokens, num_experts, topk in configs:
+    # Simplify configs for CI environment
+    if IS_CI:
+        test_configs = [(20, 32, 2)]  # Single config for CI
+    else:
+        test_configs = [
+            (20, 256, 4),
+            (20, 256, 8),
+            (20, 12, 4),
+            (20, 12, 1),
+            (20, 512, 4),
+            (20, 512, 1),
+        ]
+    for num_tokens, num_experts, topk in test_configs:
         calculate_diff(num_tokens, num_experts, topk)

     benchmark.run(print_data=True)
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py

 import argparse
 import copy
 import itertools
+import os

 import torch
 import triton
 from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

+from sglang.srt.utils import get_device_capability
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 FLOAT4_E2M1_MAX = 6.0
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

@@ -162,9 +171,22 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    KN_model_names = prepare_shapes(args)
-    for K, N, model_name in KN_model_names:
-        print(f"{model_name} N={N} K={K}: ")
-        benchmark.run(print_data=True, N=N, K=K)
-
-    print("Benchmark finished!")
+    # Check architecture compatibility - FP4 operations require sm100a/sm103a
+    major, minor = get_device_capability()
+    if major is None or major < 10:
+        # Requires compute capability 10.0+ (sm100a/sm103a)
+        print("Skipping NVIDIA FP4 scaled GEMM benchmark")
+        if major is not None:
+            print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
+        else:
+            print("Could not determine device capability")
+    else:
+        KN_model_names = prepare_shapes(args)
+        # Limit iterations in CI
+        if IS_CI:
+            KN_model_names = KN_model_names[:2]  # Only test first 2 shapes in CI
+        for K, N, model_name in KN_model_names:
+            print(f"{model_name} N={N} K={K}: ")
+            benchmark.run(print_data=True, N=N, K=K)
+
+        print("Benchmark finished!")
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py

 import itertools
 import math
+import os
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np

@@ -7,11 +8,26 @@ import torch
 import triton
 import triton.testing
 from sgl_kernel import sgl_per_tensor_quant_fp8
-from vllm import _custom_ops as ops
+
+# Optional imports
+try:
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    ops = None
+    VLLM_AVAILABLE = False

 from sglang.srt.utils import is_hip

 _is_hip = is_hip()
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

@@ -19,6 +35,9 @@ def vllm_scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    if not VLLM_AVAILABLE:
+        # Fallback to SGLang implementation
+        return sglang_scaled_fp8_quant(input, scale)
     return ops.scaled_fp8_quant(input, scale)

@@ -42,6 +61,10 @@ def calculate_diff(batch_size: int, seq_len: int):
     device = torch.device("cuda")
     x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)

+    if not VLLM_AVAILABLE:
+        print("⚠️ vLLM not available, skipping comparison")
+        return
+
     vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
     sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

@@ -56,8 +79,13 @@ def calculate_diff(batch_size: int, seq_len: int):
         print("❌ Implementations differ")

-batch_size_range = [16, 32, 64, 128]
-seq_len_range = [64, 128, 256, 512, 1024, 2048]
+# CI environment uses simplified parameters
+if IS_CI:
+    batch_size_range = [16]  # Single batch size for CI
+    seq_len_range = [64]  # Single sequence length for CI
+else:
+    batch_size_range = [16, 32, 64, 128]
+    seq_len_range = [64, 128, 256, 512, 1024, 2048]

 configs = list(itertools.product(batch_size_range, seq_len_range))
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py

 import itertools
+import os
 import time
 from functools import partial
 from pathlib import Path

@@ -16,15 +17,28 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
 from sglang.srt.utils import is_hip

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

-num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
-hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
-group_size_range = [128]  # For DeepSeek V3/R1
-# TODO test int8
-dst_dtype_range = [fp8_type_]
+# CI environment uses simplified parameters
+if IS_CI:
+    num_tokens_range = [64]  # Single value for CI
+    hidden_dim_range = [1536]  # Single value for CI
+    group_size_range = [128]  # Keep as is
+    dst_dtype_range = [fp8_type_]  # Keep as is
+else:
+    num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
+    hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
+    group_size_range = [128]  # For DeepSeek V3/R1
+    # TODO test int8
+    dst_dtype_range = [fp8_type_]

 flags_range = [
     dict(
         column_major_scales=False,

@@ -82,7 +96,7 @@ def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
     x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)

     fn, kernel_names = {
-        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
+        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"),
         "sglang": (
             sglang_per_token_group_quant_8bit,
             "per_token_group_quant_8bit_kernel",