Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
426b7493
"vscode:/vscode.git/clone" did not exist on "2e7633982ce51ffd360ca8932cd0f1254daa4e11"
Unverified
Commit
426b7493
authored
Jul 27, 2025
by
Qi Yuhang
Committed by
GitHub
Jul 26, 2025
Browse files
Add nvfp4 scaled mm benchmark. (#8401)
parent
528bd1ed
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
172 additions
and
0 deletions
+172
-0
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py
+172
-0
No files found.
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py
0 → 100644
View file @
426b7493
import
argparse
import
copy
import
itertools
import
torch
import
triton
from
sgl_kernel
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
FLOAT4_E2M1_MAX
=
6.0
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES
=
{
"meta-llama/Llama-3.1-8B-Instruct"
:
[
([
4096
,
6144
],
1
),
([
4096
,
4096
],
0
),
([
4096
,
28672
],
1
),
([
14336
,
4096
],
0
),
],
"meta-llama/Llama-3.3-70B-Instruct"
:
[
([
8192
,
10240
],
1
),
([
8192
,
8192
],
0
),
([
8192
,
57344
],
1
),
([
28672
,
8192
],
0
),
],
"mistralai/Mistral-Large-Instruct-2407"
:
[
([
12288
,
14336
],
1
),
([
12288
,
12288
],
0
),
([
12288
,
57344
],
1
),
([
28672
,
12288
],
0
),
],
"Qwen/Qwen2.5-7B-Instruct"
:
[
([
3584
,
4608
],
1
),
([
3584
,
3584
],
0
),
([
3584
,
37888
],
1
),
([
18944
,
3584
],
0
),
],
"Qwen/Qwen2.5-32B-Instruct"
:
[
([
5120
,
7168
],
1
),
([
5120
,
5120
],
0
),
([
5120
,
55296
],
1
),
([
27648
,
5120
],
0
),
],
"Qwen/Qwen2.5-72B-Instruct"
:
[
([
8192
,
10240
],
1
),
([
8192
,
8192
],
0
),
([
8192
,
59136
],
1
),
([
29568
,
8192
],
0
),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
:
[
([
2048
,
3072
],
1
),
([
2048
,
4096
],
1
),
([
2048
,
2048
],
0
),
([
2048
,
576
],
0
),
([
2048
,
21888
],
1
),
([
10944
,
2048
],
0
),
([
2048
,
2816
],
1
),
([
1408
,
2048
],
0
),
],
}
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_vals
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
[
"sglang-fp4-fp16"
,
"sglang-fp4-bf16"
,
],
line_names
=
[
"sglang-fp4-fp16"
,
"sglang-fp4-bf16"
,
],
styles
=
[(
"green"
,
"-"
),
(
"blue"
,
"-"
)],
ylabel
=
"TFLOPS"
,
plot_name
=
"fp4 block scaled matmul"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
provider
,
N
,
K
):
# M, N, K = batch_size, 4096, 8192
run_step
=
100
dtype
=
torch
.
float16
if
"fp16"
in
provider
else
torch
.
bfloat16
M
=
batch_size
a
=
torch
.
randn
((
M
,
K
),
dtype
=
dtype
,
device
=
"cuda"
)
b
=
torch
.
randn
((
N
,
K
),
dtype
=
dtype
,
device
=
"cuda"
)
a_global_scale
=
(
(
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
a
.
flatten
(),
dim
=-
1
)
).
to
(
torch
.
float32
)
b_global_scale
=
(
(
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
b
.
flatten
(),
dim
=-
1
)
).
to
(
torch
.
float32
)
alpha
=
1.0
/
(
a_global_scale
*
b_global_scale
)
a_fp4
,
a_scale_interleaved
=
scaled_fp4_quant
(
a
,
a_global_scale
)
b_fp4
,
b_scale_interleaved
=
scaled_fp4_quant
(
b
,
b_global_scale
)
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
# Bridging the gap between CPU and GPU
for
_
in
range
(
25
):
c
=
a
@
b
.
t
()
# Warmup
for
_
in
range
(
5
):
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
dtype
)
start_event
.
record
()
for
_
in
range
(
run_step
):
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
dtype
)
end_event
.
record
()
end_event
.
synchronize
()
torch
.
cuda
.
synchronize
()
ms
=
start_event
.
elapsed_time
(
end_event
)
/
run_step
tflops
=
lambda
ms
:
(
2
*
M
*
N
*
K
)
*
1e-9
/
ms
return
tflops
(
ms
)
def
prepare_shapes
(
args
):
KN_model_names
=
[]
models_tps
=
list
(
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
))
for
model
,
tp_size
in
models_tps
:
assert
model
in
WEIGHT_SHAPES
for
KN
,
tp_split_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_split_dim
]
=
KN
[
tp_split_dim
]
//
tp_size
KN
.
append
(
model
)
KN_model_names
.
append
(
KN
)
return
KN_model_names
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.1-8B-Instruct"
],
help
=
"List of models to benchmark"
,
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
],
help
=
"List of tensor parallel sizes"
,
)
args
=
parser
.
parse_args
()
KN_model_names
=
prepare_shapes
(
args
)
for
K
,
N
,
model_name
in
KN_model_names
:
print
(
f
"
{
model_name
}
N=
{
N
}
K=
{
K
}
: "
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
save_path
=
"bench_fp4_res"
,
N
=
N
,
K
=
K
)
print
(
"Benchmark finished!"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment