Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0bbac1c1
Unverified
Commit
0bbac1c1
authored
Jul 10, 2025
by
Michael Goin
Committed by
GitHub
Jul 09, 2025
Browse files
[Bench] Add NVFP4 GEMM benchmark script (#20578)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
a3e4e85e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
141 additions
and
0 deletions
+141
-0
benchmarks/kernels/bench_nvfp4_gemm.py
benchmarks/kernels/bench_nvfp4_gemm.py
+141
-0
No files found.
benchmarks/kernels/bench_nvfp4_gemm.py
0 → 100644
View file @
0bbac1c1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
torch
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.triton_utils
import
triton
if
not
current_platform
.
has_device_capability
(
100
):
raise
RuntimeError
(
"NVFP4 requires compute capability of 10.0 (Blackwell)"
)
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
PROVIDER_CFGS
=
{
"torch-bf16"
:
dict
(
enabled
=
True
),
"nvfp4"
:
dict
(
no_a_quant
=
False
,
enabled
=
True
),
"nvfp4-noquant"
:
dict
(
no_a_quant
=
True
,
enabled
=
True
),
}
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
def
_quant_weight_nvfp4
(
b
:
torch
.
Tensor
,
device
:
str
):
# Compute global scale for weight
b_amax
=
torch
.
abs
(
b
).
max
().
to
(
torch
.
float32
)
b_global_scale
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
b_amax
b_fp4
,
scale_b_fp4
=
ops
.
scaled_fp4_quant
(
b
,
b_global_scale
)
return
b_fp4
,
scale_b_fp4
,
b_global_scale
def
build_nvfp4_runner
(
cfg
,
a
,
b
,
dtype
,
device
):
b_fp4
,
scale_b_fp4
,
b_global_scale
=
_quant_weight_nvfp4
(
b
,
device
)
# Compute global scale for activation
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
a_amax
=
torch
.
abs
(
a
).
max
().
to
(
torch
.
float32
)
a_global_scale
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
a_amax
# Alpha for the GEMM operation
alpha
=
1.0
/
(
a_global_scale
*
b_global_scale
)
if
cfg
[
"no_a_quant"
]:
# Pre-quantize activation
a_fp4
,
scale_a_fp4
=
ops
.
scaled_fp4_quant
(
a
,
a_global_scale
)
def
run
():
return
ops
.
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
scale_a_fp4
,
scale_b_fp4
,
alpha
,
dtype
)
return
run
# Quantize activation on-the-fly
def
run
():
a_fp4
,
scale_a_fp4
=
ops
.
scaled_fp4_quant
(
a
,
a_global_scale
)
return
ops
.
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
scale_a_fp4
,
scale_b_fp4
,
alpha
,
dtype
)
return
run
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_vals
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
_enabled
,
line_names
=
_enabled
,
ylabel
=
"TFLOP/s (larger is better)"
,
plot_name
=
"BF16 vs NVFP4 GEMMs"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
provider
,
N
,
K
):
M
=
batch_size
device
=
"cuda"
dtype
=
torch
.
bfloat16
a
=
torch
.
randn
((
M
,
K
),
device
=
device
,
dtype
=
dtype
)
b
=
torch
.
randn
((
N
,
K
),
device
=
device
,
dtype
=
dtype
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"torch-bf16"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
torch
.
nn
.
functional
.
linear
(
a
,
b
),
quantiles
=
quantiles
)
else
:
cfg
=
PROVIDER_CFGS
[
provider
]
run_quant
=
build_nvfp4_runner
(
cfg
,
a
,
b
,
dtype
,
device
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
run_quant
(),
quantiles
=
quantiles
)
to_tflops
=
lambda
t_ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
t_ms
*
1e-3
)
return
to_tflops
(
ms
),
to_tflops
(
max_ms
),
to_tflops
(
min_ms
)
def
prepare_shapes
(
args
):
out
=
[]
for
model
,
tp_size
in
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
):
for
KN
,
tp_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_dim
]
//=
tp_size
KN
.
append
(
model
)
out
.
append
(
KN
)
return
out
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.1-8B-Instruct"
],
choices
=
list
(
WEIGHT_SHAPES
.
keys
()),
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
])
args
=
parser
.
parse_args
()
for
K
,
N
,
model
in
prepare_shapes
(
args
):
print
(
f
"
{
model
}
, N=
{
N
}
K=
{
K
}
, BF16 vs NVFP4 GEMMs TFLOP/s:"
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
save_path
=
f
"bench_nvfp4_res_n
{
N
}
_k
{
K
}
"
,
N
=
N
,
K
=
K
,
)
print
(
"Benchmark finished!"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment