Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f49239cb
Unverified
Commit
f49239cb
authored
May 30, 2025
by
Michael Goin
Committed by
GitHub
May 30, 2025
Browse files
Benchmark script for fp8 vs bf16 gemm (#17126)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
2dbe8c07
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
268 additions
and
0 deletions
+268
-0
benchmarks/kernels/bench_fp8_gemm.py
benchmarks/kernels/bench_fp8_gemm.py
+222
-0
benchmarks/kernels/weight_shapes.py
benchmarks/kernels/weight_shapes.py
+46
-0
No files found.
benchmarks/kernels/bench_fp8_gemm.py
0 → 100644
View file @
f49239cb
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
copy
import
itertools
import
torch
import
triton
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm._custom_ops
import
cutlass_scaled_mm
as
vllm_scaled_mm
from
vllm._custom_ops
import
scaled_fp8_quant
as
vllm_scaled_fp8_quant
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_vals
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
[
"torch-bf16"
,
# "fp8-tensor-w-token-a",
"fp8-tensor-w-tensor-a"
,
"fp8-channel-w-token-a"
,
# "fp8-channel-w-tensor-a",
# "fp8-tensor-w-token-a-noquant",
"fp8-tensor-w-tensor-a-noquant"
,
"fp8-channel-w-token-a-noquant"
,
# "fp8-channel-w-tensor-a-noquant",
],
line_names
=
[
"torch-bf16"
,
# "fp8-tensor-w-token-a",
"fp8-tensor-w-tensor-a"
,
"fp8-channel-w-token-a"
,
# "fp8-channel-w-tensor-a",
# "fp8-tensor-w-token-a-noquant",
"fp8-tensor-w-tensor-a-noquant"
,
"fp8-channel-w-token-a-noquant"
,
# "fp8-channel-w-tensor-a-noquant",
],
ylabel
=
"TFLOP/s (larger is better)"
,
plot_name
=
"BF16 vs FP8 GEMMs"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
provider
,
N
,
K
):
M
=
batch_size
device
=
"cuda"
dtype
=
torch
.
bfloat16
# Create input tensors
a
=
torch
.
randn
((
M
,
K
),
device
=
device
,
dtype
=
dtype
)
b
=
torch
.
randn
((
N
,
K
),
device
=
device
,
dtype
=
dtype
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
"torch-bf16"
in
provider
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
torch
.
nn
.
functional
.
linear
(
a
,
b
),
quantiles
=
quantiles
)
elif
"fp8"
in
provider
:
# Weights are always quantized ahead of time
if
"noquant"
in
provider
:
# For no quantization, we just measure the GEMM
if
"tensor-w-token-a"
in
provider
:
# Dynamic per-token quant for A, per-tensor quant for B
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
)
assert
scale_b_fp8
.
numel
()
==
1
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
def
run_quant
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"tensor-w-tensor-a"
in
provider
:
# Static per-tensor quantization with fixed scales
# for both A and B
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
assert
scale_b_fp8
.
numel
()
==
1
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a
)
def
run_quant
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"channel-w-token-a"
in
provider
:
# Static per-channel quantization for weights, per-token
# quant for A
scale_b
=
torch
.
tensor
((
N
,),
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
scale_b_fp8
=
scale_b_fp8
.
expand
(
N
).
contiguous
()
assert
scale_b_fp8
.
numel
()
==
N
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
def
run_quant
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"channel-w-tensor-a"
in
provider
:
# Static per-channel quantization for weights, per-tensor
# quant for A
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
((
N
,),
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
scale_b_fp8
=
scale_b_fp8
.
expand
(
N
).
contiguous
()
assert
scale_b_fp8
.
numel
()
==
N
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a
)
def
run_quant
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
else
:
# In these cases, we quantize the activations during the GEMM call
if
"tensor-w-token-a"
in
provider
:
# Dynamic per-token quant for A, per-tensor quant for B
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
)
assert
scale_b_fp8
.
numel
()
==
1
def
run_quant
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"tensor-w-tensor-a"
in
provider
:
# Static per-tensor quantization with fixed scales
# for both A and B
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
assert
scale_b_fp8
.
numel
()
==
1
def
run_quant
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"channel-w-token-a"
in
provider
:
# Static per-channel quantization for weights, per-token
# quant for A
scale_b
=
torch
.
tensor
((
N
,),
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
scale_b_fp8
=
scale_b_fp8
.
expand
(
N
).
contiguous
()
assert
scale_b_fp8
.
numel
()
==
N
def
run_quant
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"channel-w-tensor-a"
in
provider
:
# Static per-channel quantization for weights, per-tensor
# quant for A
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
((
N
,),
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
scale_b_fp8
=
scale_b_fp8
.
expand
(
N
).
contiguous
()
assert
scale_b_fp8
.
numel
()
==
N
def
run_quant
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
b_fp8
=
b_fp8
.
t
()
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
run_quant
(),
quantiles
=
quantiles
)
# Calculate TFLOP/s, two flops per multiply-add
tflops
=
lambda
ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
ms
*
1e-3
)
return
tflops
(
ms
),
tflops
(
max_ms
),
tflops
(
min_ms
)
def
prepare_shapes
(
args
):
KN_model_names
=
[]
models_tps
=
list
(
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
))
for
model
,
tp_size
in
models_tps
:
assert
model
in
WEIGHT_SHAPES
for
KN
,
tp_split_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_split_dim
]
=
KN
[
tp_split_dim
]
//
tp_size
KN
.
append
(
model
)
KN_model_names
.
append
(
KN
)
return
KN_model_names
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.1-8B-Instruct"
],
choices
=
[
*
WEIGHT_SHAPES
.
keys
()],
help
=
"List of models to benchmark"
,
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
],
help
=
"List of tensor parallel sizes"
,
)
args
=
parser
.
parse_args
()
KN_model_names
=
prepare_shapes
(
args
)
for
K
,
N
,
model_name
in
KN_model_names
:
print
(
f
"
{
model_name
}
, N=
{
N
}
K=
{
K
}
, BF16 vs FP8 GEMMs TFLOP/s:"
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
save_path
=
f
"bench_fp8_res_n
{
N
}
_k
{
K
}
"
,
N
=
N
,
K
=
K
,
)
print
(
"Benchmark finished!"
)
benchmarks/kernels/weight_shapes.py
View file @
f49239cb
...
...
@@ -48,4 +48,50 @@ WEIGHT_SHAPES = {
([
16384
,
106496
],
1
),
([
53248
,
16384
],
0
),
],
"meta-llama/Llama-3.1-8B-Instruct"
:
[
([
4096
,
6144
],
1
),
([
4096
,
4096
],
0
),
([
4096
,
28672
],
1
),
([
14336
,
4096
],
0
),
],
"meta-llama/Llama-3.3-70B-Instruct"
:
[
([
8192
,
10240
],
1
),
([
8192
,
8192
],
0
),
([
8192
,
57344
],
1
),
([
28672
,
8192
],
0
),
],
"mistralai/Mistral-Large-Instruct-2407"
:
[
([
12288
,
14336
],
1
),
([
12288
,
12288
],
0
),
([
12288
,
57344
],
1
),
([
28672
,
12288
],
0
),
],
"Qwen/Qwen2.5-7B-Instruct"
:
[
([
3584
,
4608
],
1
),
([
3584
,
3584
],
0
),
([
3584
,
37888
],
1
),
([
18944
,
3584
],
0
),
],
"Qwen/Qwen2.5-32B-Instruct"
:
[
([
5120
,
7168
],
1
),
([
5120
,
5120
],
0
),
([
5120
,
55296
],
1
),
([
27648
,
5120
],
0
),
],
"Qwen/Qwen2.5-72B-Instruct"
:
[
([
8192
,
10240
],
1
),
([
8192
,
8192
],
0
),
([
8192
,
59136
],
1
),
([
29568
,
8192
],
0
),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
:
[
([
2048
,
3072
],
1
),
([
2048
,
4096
],
1
),
([
2048
,
2048
],
0
),
([
2048
,
576
],
0
),
([
2048
,
21888
],
1
),
([
10944
,
2048
],
0
),
([
2048
,
2816
],
1
),
([
1408
,
2048
],
0
),
],
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment