Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d330c4c
Unverified
Commit
3d330c4c
authored
Jun 15, 2025
by
Wentao Ye
Committed by
GitHub
Jun 15, 2025
Browse files
[Benchmark] Refactor benchmark script for fp8 & int8 (#19627)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
0b73736a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
184 additions
and
280 deletions
+184
-280
benchmarks/kernels/bench_fp8_gemm.py
benchmarks/kernels/bench_fp8_gemm.py
+92
-157
benchmarks/kernels/bench_int8_gemm.py
benchmarks/kernels/bench_int8_gemm.py
+92
-123
No files found.
benchmarks/kernels/bench_fp8_gemm.py
View file @
3d330c4c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
...
...
@@ -11,6 +10,80 @@ from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from
vllm._custom_ops
import
scaled_fp8_quant
as
vllm_scaled_fp8_quant
from
vllm.triton_utils
import
triton
PROVIDER_CFGS
=
{
"torch-bf16"
:
dict
(
enabled
=
True
),
"fp8-tensor-w-token-a"
:
dict
(
w
=
"tensor"
,
a
=
"token"
,
no_a_quant
=
False
,
enabled
=
False
),
"fp8-tensor-w-tensor-a"
:
dict
(
w
=
"tensor"
,
a
=
"tensor"
,
no_a_quant
=
False
,
enabled
=
True
),
"fp8-channel-w-token-a"
:
dict
(
w
=
"channel"
,
a
=
"token"
,
no_a_quant
=
False
,
enabled
=
True
),
"fp8-channel-w-tensor-a"
:
dict
(
w
=
"channel"
,
a
=
"tensor"
,
no_a_quant
=
False
,
enabled
=
False
),
"fp8-tensor-w-token-a-noquant"
:
dict
(
w
=
"tensor"
,
a
=
"token"
,
no_a_quant
=
True
,
enabled
=
False
),
"fp8-tensor-w-tensor-a-noquant"
:
dict
(
w
=
"tensor"
,
a
=
"tensor"
,
no_a_quant
=
True
,
enabled
=
True
),
"fp8-channel-w-token-a-noquant"
:
dict
(
w
=
"channel"
,
a
=
"token"
,
no_a_quant
=
True
,
enabled
=
True
),
"fp8-channel-w-tensor-a-noquant"
:
dict
(
w
=
"channel"
,
a
=
"tensor"
,
no_a_quant
=
True
,
enabled
=
False
),
}
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
def
_quant_weight_fp8
(
b
:
torch
.
Tensor
,
w_type
:
str
,
device
:
str
):
if
w_type
==
"tensor"
:
scale_b
=
torch
.
ones
(
1
,
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
else
:
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
use_per_token_if_dynamic
=
True
)
return
b_fp8
.
t
(),
scale_b_fp8
def
build_fp8_runner
(
cfg
,
a
,
b
,
dtype
,
device
):
b_fp8
,
scale_b_fp8
=
_quant_weight_fp8
(
b
,
cfg
[
"w"
],
device
)
scale_a_const
=
(
torch
.
ones
(
1
,
device
=
device
,
dtype
=
torch
.
float32
)
if
cfg
[
"a"
]
==
"tensor"
else
None
)
if
cfg
[
"no_a_quant"
]:
if
cfg
[
"a"
]
==
"tensor"
:
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a_const
)
else
:
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
def
run
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
return
run
if
cfg
[
"a"
]
==
"tensor"
:
def
run
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a_const
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
else
:
def
run
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
return
run
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
...
...
@@ -18,28 +91,8 @@ from vllm.triton_utils import triton
x_vals
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
[
"torch-bf16"
,
# "fp8-tensor-w-token-a",
"fp8-tensor-w-tensor-a"
,
"fp8-channel-w-token-a"
,
# "fp8-channel-w-tensor-a",
# "fp8-tensor-w-token-a-noquant",
"fp8-tensor-w-tensor-a-noquant"
,
"fp8-channel-w-token-a-noquant"
,
# "fp8-channel-w-tensor-a-noquant",
],
line_names
=
[
"torch-bf16"
,
# "fp8-tensor-w-token-a",
"fp8-tensor-w-tensor-a"
,
"fp8-channel-w-token-a"
,
# "fp8-channel-w-tensor-a",
# "fp8-tensor-w-token-a-noquant",
"fp8-tensor-w-tensor-a-noquant"
,
"fp8-channel-w-token-a-noquant"
,
# "fp8-channel-w-tensor-a-noquant",
],
line_vals
=
_enabled
,
line_names
=
_enabled
,
ylabel
=
"TFLOP/s (larger is better)"
,
plot_name
=
"BF16 vs FP8 GEMMs"
,
args
=
{},
...
...
@@ -50,144 +103,34 @@ def benchmark(batch_size, provider, N, K):
device
=
"cuda"
dtype
=
torch
.
bfloat16
# Create input tensors
a
=
torch
.
randn
((
M
,
K
),
device
=
device
,
dtype
=
dtype
)
b
=
torch
.
randn
((
N
,
K
),
device
=
device
,
dtype
=
dtype
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
"torch-bf16"
in
provider
:
if
provider
==
"torch-bf16"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
torch
.
nn
.
functional
.
linear
(
a
,
b
),
quantiles
=
quantiles
)
elif
"fp8"
in
provider
:
# Weights are always quantized ahead of time
if
"noquant"
in
provider
:
# For no quantization, we just measure the GEMM
if
"tensor-w-token-a"
in
provider
:
# Dynamic per-token quant for A, per-tensor quant for B
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
)
assert
scale_b_fp8
.
numel
()
==
1
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
def
run_quant
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"tensor-w-tensor-a"
in
provider
:
# Static per-tensor quantization with fixed scales
# for both A and B
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
assert
scale_b_fp8
.
numel
()
==
1
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a
)
def
run_quant
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"channel-w-token-a"
in
provider
:
# Static per-channel quantization for weights, per-token
# quant for A
scale_b
=
torch
.
tensor
((
N
,),
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
scale_b_fp8
=
scale_b_fp8
.
expand
(
N
).
contiguous
()
assert
scale_b_fp8
.
numel
()
==
N
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
def
run_quant
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"channel-w-tensor-a"
in
provider
:
# Static per-channel quantization for weights, per-tensor
# quant for A
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
((
N
,),
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
scale_b_fp8
=
scale_b_fp8
.
expand
(
N
).
contiguous
()
assert
scale_b_fp8
.
numel
()
==
N
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a
)
def
run_quant
():
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
else
:
# In these cases, we quantize the activations during the GEMM call
if
"tensor-w-token-a"
in
provider
:
# Dynamic per-token quant for A, per-tensor quant for B
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
)
assert
scale_b_fp8
.
numel
()
==
1
def
run_quant
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"tensor-w-tensor-a"
in
provider
:
# Static per-tensor quantization with fixed scales
# for both A and B
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
assert
scale_b_fp8
.
numel
()
==
1
def
run_quant
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"channel-w-token-a"
in
provider
:
# Static per-channel quantization for weights, per-token
# quant for A
scale_b
=
torch
.
tensor
((
N
,),
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
scale_b_fp8
=
scale_b_fp8
.
expand
(
N
).
contiguous
()
assert
scale_b_fp8
.
numel
()
==
N
def
run_quant
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
use_per_token_if_dynamic
=
True
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
elif
"channel-w-tensor-a"
in
provider
:
# Static per-channel quantization for weights, per-tensor
# quant for A
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
((
N
,),
device
=
device
,
dtype
=
torch
.
float32
)
b_fp8
,
scale_b_fp8
=
vllm_scaled_fp8_quant
(
b
,
scale_b
)
scale_b_fp8
=
scale_b_fp8
.
expand
(
N
).
contiguous
()
assert
scale_b_fp8
.
numel
()
==
N
def
run_quant
():
a_fp8
,
scale_a_fp8
=
vllm_scaled_fp8_quant
(
a
,
scale_a
)
return
vllm_scaled_mm
(
a_fp8
,
b_fp8
,
scale_a_fp8
,
scale_b_fp8
,
dtype
)
b_fp8
=
b_fp8
.
t
()
cfg
=
PROVIDER_CFGS
[
provider
]
run_quant
=
build_fp8_runner
(
cfg
,
a
,
b
,
dtype
,
device
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
run_quant
(),
quantiles
=
quantiles
)
# Calculate TFLOP/s, two flops per multiply-add
tflops
=
lambda
ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
ms
*
1e-3
)
return
tflops
(
ms
),
tflops
(
max_ms
),
tflops
(
min_ms
)
to_tflops
=
lambda
t_ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
t_ms
*
1e-3
)
return
to_tflops
(
ms
),
to_tflops
(
max_ms
),
to_tflops
(
min_ms
)
def
prepare_shapes
(
args
):
KN_model_names
=
[]
models_tps
=
list
(
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
))
for
model
,
tp_size
in
models_tps
:
assert
model
in
WEIGHT_SHAPES
for
KN
,
tp_split_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_split_dim
]
=
KN
[
tp_split_dim
]
//
tp_size
out
=
[]
for
model
,
tp_size
in
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
):
for
KN
,
tp_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_dim
]
//=
tp_size
KN
.
append
(
model
)
KN_model_names
.
append
(
KN
)
return
KN_model_names
out
.
append
(
KN
)
return
out
if
__name__
==
"__main__"
:
...
...
@@ -197,21 +140,13 @@ if __name__ == "__main__":
nargs
=
"+"
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.1-8B-Instruct"
],
choices
=
[
*
WEIGHT_SHAPES
.
keys
()],
help
=
"List of models to benchmark"
,
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
],
help
=
"List of tensor parallel sizes"
,
choices
=
list
(
WEIGHT_SHAPES
.
keys
()),
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
])
args
=
parser
.
parse_args
()
KN_model_names
=
prepare_shapes
(
args
)
for
K
,
N
,
model_name
in
KN_model_names
:
print
(
f
"
{
model_name
}
, N=
{
N
}
K=
{
K
}
, BF16 vs FP8 GEMMs TFLOP/s:"
)
for
K
,
N
,
model
in
prepare_shapes
(
args
):
print
(
f
"
{
model
}
, N=
{
N
}
K=
{
K
}
, BF16 vs FP8 GEMMs TFLOP/s:"
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
...
...
benchmarks/kernels/bench_int8_gemm.py
View file @
3d330c4c
...
...
@@ -11,6 +11,84 @@ from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from
vllm._custom_ops
import
scaled_int8_quant
as
vllm_scaled_int8_quant
from
vllm.triton_utils
import
triton
PROVIDER_CFGS
=
{
"torch-bf16"
:
dict
(
enabled
=
True
),
"int8-tensor-w-token-a"
:
dict
(
w
=
"tensor"
,
a
=
"token"
,
no_a_quant
=
False
,
enabled
=
False
),
"int8-tensor-w-tensor-a"
:
dict
(
w
=
"tensor"
,
a
=
"tensor"
,
no_a_quant
=
False
,
enabled
=
True
),
"int8-channel-w-token-a"
:
dict
(
w
=
"channel"
,
a
=
"token"
,
no_a_quant
=
False
,
enabled
=
True
),
"int8-channel-w-tensor-a"
:
dict
(
w
=
"channel"
,
a
=
"tensor"
,
no_a_quant
=
False
,
enabled
=
False
),
"int8-tensor-w-token-a-noquant"
:
dict
(
w
=
"tensor"
,
a
=
"token"
,
no_a_quant
=
True
,
enabled
=
False
),
"int8-tensor-w-tensor-a-noquant"
:
dict
(
w
=
"tensor"
,
a
=
"tensor"
,
no_a_quant
=
True
,
enabled
=
True
),
"int8-channel-w-token-a-noquant"
:
dict
(
w
=
"channel"
,
a
=
"token"
,
no_a_quant
=
True
,
enabled
=
True
),
"int8-channel-w-tensor-a-noquant"
:
dict
(
w
=
"channel"
,
a
=
"tensor"
,
no_a_quant
=
True
,
enabled
=
False
),
}
def
_quant_weight
(
b
,
w_type
,
device
):
if
w_type
==
"tensor"
:
scale_b
=
torch
.
ones
(
1
,
device
=
device
,
dtype
=
torch
.
float32
)
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
,
scale_b
)
assert
scale_b_int8
.
numel
()
==
1
else
:
# channel
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
)
assert
scale_b_int8
.
numel
()
==
b
.
shape
[
0
]
return
b_int8
.
t
(),
scale_b_int8
def
build_int8_runner
(
cfg
,
a
,
b
,
dtype
,
device
):
# quant before running the kernel
b_int8
,
scale_b_int8
=
_quant_weight
(
b
,
cfg
[
"w"
],
device
)
scale_a_const
=
None
if
cfg
[
"a"
]
==
"tensor"
:
scale_a_const
=
torch
.
ones
(
1
,
device
=
device
,
dtype
=
torch
.
float32
)
# no quant, create activation ahead
if
cfg
[
"no_a_quant"
]:
if
cfg
[
"a"
]
==
"tensor"
:
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
,
scale_a_const
)
else
:
# token
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
)
def
run_quant
():
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
return
run_quant
# dynamic quant, create activation inside
if
cfg
[
"a"
]
==
"tensor"
:
def
run_quant
():
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
,
scale_a_const
)
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
else
:
# token
def
run_quant
():
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
)
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
return
run_quant
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
.
get
(
"enabled"
)]
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
...
...
@@ -18,28 +96,8 @@ from vllm.triton_utils import triton
x_vals
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
[
"torch-bf16"
,
# "int8-tensor-w-token-a",
"int8-tensor-w-tensor-a"
,
"int8-channel-w-token-a"
,
# "int8-channel-w-tensor-a",
# "int8-tensor-w-token-a-noquant",
"int8-tensor-w-tensor-a-noquant"
,
"int8-channel-w-token-a-noquant"
,
# "int8-channel-w-tensor-a-noquant",
],
line_names
=
[
"torch-bf16"
,
# "int8-tensor-w-token-a",
"int8-tensor-w-tensor-a"
,
"int8-channel-w-token-a"
,
# "int8-channel-w-tensor-a",
# "int8-tensor-w-token-a-noquant",
"int8-tensor-w-tensor-a-noquant"
,
"int8-channel-w-token-a-noquant"
,
# "int8-channel-w-tensor-a-noquant",
],
line_vals
=
_enabled
,
line_names
=
[
k
for
k
in
_enabled
],
ylabel
=
"TFLOP/s (larger is better)"
,
plot_name
=
"BF16 vs INT8 GEMMs"
,
args
=
{},
...
...
@@ -54,114 +112,26 @@ def benchmark(batch_size, provider, N, K):
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
"torch-bf16"
in
provider
:
if
provider
==
"torch-bf16"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
torch
.
nn
.
functional
.
linear
(
a
,
b
),
quantiles
=
quantiles
)
elif
"int8"
in
provider
:
# Weights are always quantized ahead of time
if
"noquant"
in
provider
:
# For "no quant", we don't measure the time for activations
if
"tensor-w-token-a"
in
provider
:
# Dynamic per-token quant for A, static per-tensor quant for B
scale_b
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
,
scale_b
)
assert
scale_b_int8
.
numel
()
==
1
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
)
elif
"tensor-w-tensor-a"
in
provider
:
# Static per-tensor quantization with fixed scales for both A and B
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
,
scale_b
)
assert
scale_b_int8
.
numel
()
==
1
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
,
scale_a
)
elif
"channel-w-token-a"
in
provider
:
# Dynamic per-channel quantization for weights, per-token quant for A
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
)
assert
scale_b_int8
.
numel
()
==
N
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
)
elif
"channel-w-tensor-a"
in
provider
:
# Dynamic per-channel quantization for weights, per-tensor quant for A
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
)
assert
scale_b_int8
.
numel
()
==
N
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
,
scale_a
)
def
run_quant
():
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
else
:
# Quantize the activations during the GEMM call
if
"tensor-w-token-a"
in
provider
:
# Dynamic per-token quant for A, static per-tensor quant for B
scale_b
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
,
scale_b
)
assert
scale_b_int8
.
numel
()
==
1
def
run_quant
():
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
)
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
elif
"tensor-w-tensor-a"
in
provider
:
# Static per-tensor quantization with fixed scales for both A and B
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
,
scale_b
)
assert
scale_b_int8
.
numel
()
==
1
def
run_quant
():
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
,
scale_a
)
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
elif
"channel-w-token-a"
in
provider
:
# Dynamic per-channel quant for weights, per-token quant for A
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
)
assert
scale_b_int8
.
numel
()
==
N
def
run_quant
():
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
)
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
elif
"channel-w-tensor-a"
in
provider
:
# Dynamic per-channel quant for weights, static per-tensor quant for A
scale_a
=
torch
.
tensor
([
1.0
],
device
=
device
,
dtype
=
torch
.
float32
)
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
)
assert
scale_b_int8
.
numel
()
==
N
def
run_quant
():
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
,
scale_a
)
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
b_int8
=
b_int8
.
t
()
cfg
=
PROVIDER_CFGS
[
provider
]
run_quant
=
build_int8_runner
(
cfg
,
a
,
b
,
dtype
,
device
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
run_quant
(),
quantiles
=
quantiles
)
# Calculate TFLOP/s, two flops per multiply-add
tflops
=
lambda
ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
ms
*
1e-3
)
return
tflops
(
ms
),
tflops
(
max_ms
),
tflops
(
min_ms
)
to_tflops
=
lambda
t_ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
t_ms
*
1e-3
)
return
to_tflops
(
ms
),
to_tflops
(
max_ms
),
to_tflops
(
min_ms
)
def
prepare_shapes
(
args
):
KN_model_names
=
[]
models_tps
=
list
(
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
))
for
model
,
tp_size
in
models_tps
:
assert
model
in
WEIGHT_SHAPES
for
KN
,
tp_split_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_split_dim
]
=
KN
[
tp_split_dim
]
//
tp_size
for
model
,
tp_size
in
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
):
for
KN
,
tp_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_dim
]
//=
tp_size
KN
.
append
(
model
)
KN_model_names
.
append
(
KN
)
return
KN_model_names
...
...
@@ -174,7 +144,7 @@ if __name__ == "__main__":
nargs
=
"+"
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.1-8B-Instruct"
],
choices
=
[
*
WEIGHT_SHAPES
.
keys
()
]
,
choices
=
list
(
WEIGHT_SHAPES
.
keys
()
)
,
help
=
"List of models to benchmark"
,
)
parser
.
add_argument
(
...
...
@@ -186,9 +156,8 @@ if __name__ == "__main__":
)
args
=
parser
.
parse_args
()
KN_model_names
=
prepare_shapes
(
args
)
for
K
,
N
,
model_name
in
KN_model_names
:
print
(
f
"
{
model_name
}
, N=
{
N
}
K=
{
K
}
, BF16 vs INT8 GEMMs TFLOP/s:"
)
for
K
,
N
,
model
in
prepare_shapes
(
args
):
print
(
f
"
{
model
}
, N=
{
N
}
K=
{
K
}
, BF16 vs INT8 GEMMs TFLOP/s:"
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment