raojy / vllm_017 · Commits · fbeb8a6f

Commit fbeb8a6f ("raw_vllm"), authored Mar 27, 2026 by raojy
Parent: 2ca8867f
Pipeline #3454: canceled with stages
Changes: 165 · Pipelines: 1
Showing 20 changed files with 4936 additions and 0 deletions (+4936 −0)
benchmarks/kernels/benchmark_mrope.py  +324 −0
benchmarks/kernels/benchmark_mxfp4_qutlass.py  +191 −0
benchmarks/kernels/benchmark_nvfp4_gemm.py  +198 −0
benchmarks/kernels/benchmark_nvfp4_quant.py  +210 −0
benchmarks/kernels/benchmark_nvfp4_qutlass.py  +207 −0
benchmarks/kernels/benchmark_paged_attention.py  +251 −0
benchmarks/kernels/benchmark_per_token_group_quant.py  +159 −0
benchmarks/kernels/benchmark_per_token_quant_fp8.py  +272 −0
benchmarks/kernels/benchmark_quant.py  +108 −0
benchmarks/kernels/benchmark_reshape_and_cache.py  +172 −0
benchmarks/kernels/benchmark_reshape_and_cache_flash.py  +210 −0
benchmarks/kernels/benchmark_rmsnorm.py  +255 −0
benchmarks/kernels/benchmark_rope.py  +108 −0
benchmarks/kernels/benchmark_shapes.py  +94 −0
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py  +720 −0
benchmarks/kernels/benchmark_trtllm_decode_attention.py  +290 −0
benchmarks/kernels/benchmark_trtllm_prefill_attention.py  +305 −0
benchmarks/kernels/benchmark_w8a8_block_fp8.py  +415 −0
benchmarks/kernels/cpu/benchmark_cpu_attn.py  +272 −0
benchmarks/kernels/cpu/benchmark_cpu_fused_moe.py  +175 −0
Too many changes to show: to preserve performance, only 165 of 165+ files are displayed.
benchmarks/kernels/benchmark_mrope.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
# It generates test data, runs benchmarks, and saves results to a CSV file.
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
# == Usage Examples ==
#
# Single model benchmark:
# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models benchmark:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different TP sizes:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different token counts:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
import csv
import os
import time
from datetime import datetime
from typing import Any

import numpy as np
import torch

from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.transformers_utils.config import get_config
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import set_random_seed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Generate test data for given configuration."""
    # Create 2D positions (3, num_tokens) for multimodal case
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )

    # Create query and key tensors
    query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)

    return positions, query, key


def calculate_stats(times: list[float]) -> dict[str, float]:
    """Calculate statistics from a list of times."""
    times_array = np.array(times)
    return {
        "mean": np.mean(times_array),
        "median": np.median(times_array),
        "p99": np.percentile(times_array, 99),
        "min": np.min(times_array),
        "max": np.max(times_array),
    }


@default_vllm_config()
def benchmark_mrope(
    model_name: str,
    num_tokens: int,
    head_dim: int,
    tp_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 8192,
    is_neox_style: bool = True,
    rope_parameters: dict[str, Any] | None = None,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 0,
    warmup_iter: int = 10,
    benchmark_iter: int = 100,
    csv_writer=None,
):
    set_random_seed(seed)
    torch.set_default_device(device)
    # the parameters to compute the q k v size based on tp_size
    mrope_helper_class = get_rope(
        head_size=head_dim,
        max_position=max_position,
        is_neox_style=is_neox_style,
        rope_parameters=rope_parameters,
        dtype=dtype,
    ).to(device=device)

    print(80 * "=")
    print(
        f"Evaluating model: {model_name} "
        f"with tp_size: {tp_size} "
        f"and num_tokens: {num_tokens}, "
        f"dtype: {dtype}"
    )

    # create q k v input tensors
    # create rotary pos emb input tensors
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Warm up
    for _ in range(warmup_iter):
        mrope_helper_class.forward_native(
            positions,
            query.clone(),
            key.clone(),
        )
        mrope_helper_class.forward_cuda(
            positions,
            query.clone(),
            key.clone(),
        )
    torch.cuda.synchronize()

    # Time reference implementation
    torch_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()
        mrope_helper_class.forward_native(
            positions,
            query_clone,
            key_clone,
        )
        torch.cuda.synchronize()
        torch_times.append(time.time() - start_time)

    # Time triton kernel implementation
    triton_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()
        mrope_helper_class.forward_cuda(
            positions,
            query_clone,
            key_clone,
        )
        torch.cuda.synchronize()
        triton_times.append(time.time() - start_time)

    # Calculate statistics
    torch_stats = calculate_stats(torch_times)
    triton_stats = calculate_stats(triton_times)

    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")
    print(
        f"Torch implementation: "
        f"mean={torch_stats['mean']:.8f}s, "
        f"median={torch_stats['median']:.8f}s, "
        f"p99={torch_stats['p99']:.8f}s"
    )
    print(
        f"Triton implementation: "
        f"mean={triton_stats['mean']:.8f}s, "
        f"median={triton_stats['median']:.8f}s, "
        f"p99={triton_stats['p99']:.8f}s"
    )
    print(
        f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
    )

    # Write to CSV
    if csv_writer:
        row = [
            model_name,
            tp_size,
            num_tokens,
            num_heads,
            num_kv_heads,
            head_dim,
            max_position,
            is_neox_style,
            str(rope_parameters),
            str(dtype).split(".")[-1],
            torch_stats["mean"],
            torch_stats["median"],
            torch_stats["p99"],
            torch_stats["min"],
            torch_stats["max"],
            triton_stats["mean"],
            triton_stats["median"],
            triton_stats["p99"],
            triton_stats["min"],
            triton_stats["max"],
            torch_stats["mean"] / triton_stats["mean"],  # speedup
        ]
        csv_writer.writerow(row)

    return torch_stats, triton_stats


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--warmup-iter", type=int, default=10)
    parser.add_argument("--benchmark-iter", type=int, default=100)
    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
    args = parser.parse_args()
    print(args)

    # Create CSV file for results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"

    with open(csv_filename, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        # Write header
        header = [
            "model_name",
            "tp_size",
            "num_tokens",
            "num_heads",
            "num_kv_heads",
            "head_dim",
            "max_position",
            "is_neox_style",
            "rope_parameters",
            "dtype",
            "torch_mean",
            "torch_median",
            "torch_p99",
            "torch_min",
            "torch_max",
            "triton_mean",
            "triton_median",
            "triton_p99",
            "triton_min",
            "triton_max",
            "speedup",
        ]
        csv_writer.writerow(header)

        model_tp_dict = {}
        if args.model_name == "":
            model_tp_dict = {
                "Qwen/Qwen2-VL-2B-Instruct": [1],
                "Qwen/Qwen2-VL-7B-Instruct": [1],
                "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
                "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
            }
        else:
            model_tp_dict[args.model_name] = [args.tp_size]

        if args.num_tokens is None:
            num_tokens_list = [2**i for i in range(0, 18)]
        else:
            num_tokens_list = args.num_tokens

        for model_name, tp_list in model_tp_dict.items():
            config = get_config(model_name, trust_remote_code=args.trust_remote_code)
            for tp_size in tp_list:
                # get the model config
                total_num_kv_heads = config.num_key_value_heads
                total_num_heads = config.num_attention_heads
                num_heads = total_num_heads // tp_size
                num_kv_heads = max(1, total_num_kv_heads // tp_size)
                head_dim = config.hidden_size // total_num_heads
                q_size = num_heads * head_dim
                kv_size = num_kv_heads * head_dim
                is_neox_style = True
                rope_parameters = config.rope_parameters
                max_position = config.max_position_embeddings

                for num_tokens in num_tokens_list:
                    benchmark_mrope(
                        model_name=model_name,
                        num_tokens=num_tokens,
                        head_dim=head_dim,
                        tp_size=tp_size,
                        num_heads=num_heads,
                        num_kv_heads=num_kv_heads,
                        max_position=max_position,
                        is_neox_style=is_neox_style,
                        rope_parameters=rope_parameters,
                        dtype=getattr(torch, args.dtype),
                        seed=args.seed,
                        warmup_iter=args.warmup_iter,
                        benchmark_iter=args.benchmark_iter,
                        csv_writer=csv_writer,
                    )

    print(f"Benchmark results saved to {csv_filename}")
benchmarks/kernels/benchmark_mxfp4_qutlass.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import itertools

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "mxfp4": dict(no_a_quant=False, enabled=True),
    "mxfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return (
        deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
        * group_size**-0.5
    )


def _quant_weight_mxfp4(b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str):
    weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
        b, forward_hadamard_matrix, method="abs_max"
    )
    weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
    return weight_hf_e2m1, weight_hf_scale_block


def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
        b, forward_hadamard_matrix, device
    )
    alpha = torch.tensor([1.0], device="cuda")

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")

        def run():
            return matmul_mxf4_bf16_tn(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
        return matmul_mxf4_bf16_tn(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096,
            8192, 16384, 24576, 32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs MXFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), rep=200, quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [32, 64, 128]:
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path=f"bench_mxfp4_res_n{N}_k{K}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
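Usage note: weight_shapes is imported as a local module, so the script is presumably run from the benchmarks/kernels/ directory. Going by the argparse flags defined above (the values shown are the defaults), a typical invocation might look like:

python3 benchmark_mxfp4_qutlass.py --models meta-llama/Llama-3.3-70B-Instruct --tp-sizes 1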
benchmarks/kernels/benchmark_nvfp4_gemm.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
import os

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton

if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "nvfp4": dict(no_a_quant=False, enabled=True),
    "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
    "fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True),
    "fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True),
}

_needs_fbgemm = any(
    v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False)
)
if _needs_fbgemm:
    try:
        from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
            triton_scale_nvfp4_quant,
        )
    except ImportError:
        print(
            "WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
            "These providers will be skipped. Please install fbgemm_gpu with: "
            "'pip install fbgemm-gpu-genai' to run them."
        )
        # Disable FBGEMM providers so the benchmark can run.
        for cfg in PROVIDER_CFGS.values():
            if cfg.get("fbgemm"):
                cfg["enabled"] = False

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg):
    # Compute global scale for weight
    b_amax = torch.abs(b).max().to(torch.float32)
    b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    if "fbgemm" in cfg and cfg["fbgemm"]:
        b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)
    else:
        b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
    return b_fp4, scale_b_fp4, b_global_scale


def build_nvfp4_runner(cfg, a, b, dtype, device):
    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg)

    # Compute global scale for activation
    # NOTE: This is generally provided ahead-of-time by the model checkpoint.
    a_amax = torch.abs(a).max().to(torch.float32)
    a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax

    # Alpha for the GEMM operation
    alpha = 1.0 / (a_global_scale * b_global_scale)

    if "fbgemm" in cfg and cfg["fbgemm"]:
        if cfg["no_a_quant"]:
            a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)

            def run():
                return torch.ops.fbgemm.f4f4bf16(
                    a_fp4,
                    b_fp4,
                    scale_a_fp4,
                    scale_b_fp4,
                    global_scale=alpha,
                    use_mx=False,
                )

            return run
        else:

            def run():
                a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
                return torch.ops.fbgemm.f4f4bf16(
                    a_fp4,
                    b_fp4,
                    scale_a_fp4,
                    scale_b_fp4,
                    global_scale=alpha,
                    use_mx=False,
                )

            return run

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)

        def run():
            return ops.cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
            )

        return run

    # Quantize activation on-the-fly
    def run():
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
        return ops.cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
        save_dir = f"bench_nvfp4_res_n{N}_k{K}"
        os.makedirs(save_dir, exist_ok=True)
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=save_dir,
            N=N,
            K=K,
        )

    print("Benchmark finished!")
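Usage note: going by the argparse flags above (the values shown are the script defaults), an example run on a Blackwell-class GPU might be:

python3 benchmark_nvfp4_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1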
benchmarks/kernels/benchmark_nvfp4_quant.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton
from vllm.utils.flashinfer import flashinfer_fp4_quantize

if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

PROVIDER_CFGS = {
    "vllm": dict(backend="vllm", is_sf_swizzled_layout=False, enabled=True),
    "vllm-swizzle": dict(backend="vllm", is_sf_swizzled_layout=True, enabled=True),
    "flashinfer": dict(backend="flashinfer", is_sf_swizzled_layout=False, enabled=True),
    "flashinfer-swizzle": dict(
        backend="flashinfer", is_sf_swizzled_layout=True, enabled=True
    ),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
    """Compute global scale for FP4 quantization."""
    amax = torch.abs(tensor).max().to(torch.float32)
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="us (lower is better)",
        plot_name="NVFP4 Input Quantization Latency (us)",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    # Create input tensor
    a = torch.randn((M, K), device=device, dtype=dtype)

    # Compute global scale for activation
    a_global_scale = compute_global_scale(a)

    quantiles = [0.5, 0.2, 0.8]
    cfg = PROVIDER_CFGS[provider]

    if cfg["backend"] == "vllm":
        # vLLM's FP4 quantization
        if cfg["is_sf_swizzled_layout"]:
            ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
                lambda: ops.scaled_fp4_quant(a, a_global_scale, is_sf_swizzled_layout=True),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
                lambda: ops.scaled_fp4_quant(a, a_global_scale, is_sf_swizzled_layout=False),
                quantiles=quantiles,
            )
    elif cfg["backend"] == "flashinfer":
        # FlashInfer's FP4 quantization
        if cfg["is_sf_swizzled_layout"]:
            ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
                lambda: flashinfer_fp4_quantize(
                    a, a_global_scale, is_sf_swizzled_layout=True
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
                lambda: flashinfer_fp4_quantize(
                    a, a_global_scale, is_sf_swizzled_layout=False
                ),
                quantiles=quantiles,
            )

    # Convert ms to us for better readability at small batch sizes
    to_us = lambda t_ms: t_ms * 1000
    return to_us(ms), to_us(max_ms), to_us(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


def _test_accuracy_once(
    M: int, K: int, dtype: torch.dtype, device: str, is_sf_swizzled_layout: bool
):
    """Test accuracy between vLLM and FlashInfer FP4 quantization."""
    # Create input tensor
    a = torch.randn((M, K), device=device, dtype=dtype)

    # Compute global scale
    a_global_scale = compute_global_scale(a)

    # vLLM quantization
    vllm_fp4, vllm_scale = ops.scaled_fp4_quant(
        a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
    )

    # FlashInfer quantization (with swizzled layout to match vLLM's output)
    flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize(
        a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
    )
    flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn)

    # Compare outputs
    torch.testing.assert_close(
        vllm_fp4,
        flashinfer_fp4,
    )

    # Compare scales
    torch.testing.assert_close(
        vllm_scale,
        flashinfer_scale,
    )

    print(
        f"M={M}, K={K}, dtype={dtype}, is_sf_swizzled_layout={is_sf_swizzled_layout}: PASSED"  # noqa: E501
    )


def test_accuracy():
    """Run accuracy tests across various shapes."""
    print("\n" + "=" * 60)
    print("Running accuracy tests: vLLM vs FlashInfer")
    print("=" * 60)

    device = "cuda"
    dtype = torch.bfloat16

    # Test various batch sizes and hidden dimensions
    Ms = [1, 1024]
    Ks = [4096]

    for is_sf_swizzled_layout in [True, False]:
        for M in Ms:
            for K in Ks:
                _test_accuracy_once(M, K, dtype, device, is_sf_swizzled_layout)

    print("\nAll accuracy tests passed!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark NVFP4 quantization: vLLM vs FlashInfer"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Path to save benchmark results",
    )
    parser.add_argument(
        "--accuracy",
        action="store_true",
        help="Run accuracy tests",
    )
    args = parser.parse_args()

    if args.accuracy:
        test_accuracy()

    for K, N, model in prepare_shapes(args):
        print(f"\n{model}, N={N} K={K}")
        benchmark.run(
            print_data=True,
            save_path=args.save_path,
            N=N,
            K=K,
        )

    print("\nBenchmark finished!")
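Usage sketch based on the flags above: the optional --accuracy flag runs the vLLM vs FlashInfer equivalence checks before the latency sweep, for example:

python3 benchmark_nvfp4_quant.py --accuracy --models meta-llama/Llama-3.3-70B-Instruct --tp-sizes 1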
benchmarks/kernels/benchmark_nvfp4_qutlass.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import itertools

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops

# use existing nvfp4 gemm in vllm
from vllm._custom_ops import fusedQuantizeNv
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "nvfp4": dict(no_a_quant=False, enabled=True),
    "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return (
        deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
        * group_size**-0.5
    )


def _quant_weight_nvfp4(
    b: torch.Tensor,
    forward_hadamard_matrix: torch.Tensor,
    global_scale: torch.Tensor,
    device: str,
    M: int,
    N: int,
    K: int,
):
    weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv(
        b, forward_hadamard_matrix, global_scale
    )
    weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view(
        -1, K // 16
    )
    return weight_hf_e2m1, weight_hf_scale_block


def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
    alpha = torch.tensor([1.0], device="cuda")
    global_scale = torch.tensor([1.0], device="cuda")
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
        b, forward_hadamard_matrix, global_scale, device, M, N, K
    )

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
            a, forward_hadamard_matrix, global_scale
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
            -1, K // 16
        )

        def run():
            return ops.cutlass_scaled_fp4_mm(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
                torch.bfloat16,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
            a, forward_hadamard_matrix, global_scale
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
            -1, K // 16
        )
        return ops.cutlass_scaled_fp4_mm(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
            torch.bfloat16,
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096,
            8192, 16384, 24576, 32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_nvfp4_runner(
            cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K
        )
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), rep=200, quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [16, 32, 64, 128]:
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path=f"bench_nvfp4_res_n{N}_k{K}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
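An example invocation, using only the defaults defined in the argparse block above:

python3 benchmark_nvfp4_qutlass.py --models meta-llama/Llama-3.3-70B-Instruct --tp-sizes 1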
benchmarks/kernels/benchmark_paged_attention.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time

import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    create_kv_caches_with_random,
    set_random_seed,
)

logger = init_logger(__name__)

NUM_BLOCKS = 128 * 1024
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: str | None = None,
) -> None:
    set_random_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(
        num_seqs, num_query_heads, head_size, dtype=dtype, device=device
    )
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)

    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)

    # Create the KV cache.
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        if current_platform.is_rocm():
            global PARTITION_SIZE
            if not args.custom_paged_attn and not current_platform.is_navi():
                PARTITION_SIZE = 1024
            else:
                PARTITION_SIZE = PARTITION_SIZE_ROCM
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        # Using default kv_scale
        k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

        for _ in range(num_iters):
            if version == "v1":
                ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            elif version == "v2":
                if not args.custom_paged_attn:
                    ops.paged_attention_v2(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
                else:
                    ops.paged_attention_rocm(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        None,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    logger.warning(
        "This script benchmarks the paged attention kernel. "
        "By default this is no longer used in vLLM inference."
    )
    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
    parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
    )
    parser.add_argument(
        "--custom-paged-attn", action="store_true", help="Use custom paged attention"
    )
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )
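Based on the flags above (all values shown are the defaults), a representative run might be:

python3 benchmark_paged_attention.py --version v2 --batch-size 8 --seq-len 4096 --num-query-heads 64 --num-kv-heads 8 --head-size 128 --kv-cache-dtype auto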
benchmarks/kernels/benchmark_per_token_group_quant.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from collections.abc import Callable
from contextlib import contextmanager
from unittest.mock import patch

import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
from vllm.platforms import current_platform


@contextmanager
def _triton_mode():
    """Temporarily force the Triton fallback path"""
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        yield


def _time_cuda(
    fn: Callable[[], tuple[torch.Tensor, torch.Tensor]],
    warmup_iters: int,
    bench_iters: int,
) -> float:
    # warmup
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    start = torch.Event(enable_timing=True)
    end = torch.Event(enable_timing=True)

    start.record()
    for _ in range(bench_iters):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / bench_iters  # ms/iter


def _run_single(
    shape: tuple[int, int],
    group_size: int,
    dtype: str,
    *,
    column_major: bool = False,
    scale_ue8m0: bool = False,
    warmup_iters: int,
    bench_iters: int,
) -> None:
    num_tokens, hidden_dim = shape
    device = torch.device("cuda")

    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8

    if dtype == "fp8":

        def cuda_impl():
            return fp8_utils.per_token_group_quant_fp8(
                x,
                group_size,
                column_major_scales=column_major,
                use_ue8m0=scale_ue8m0,
            )

        def triton_impl():
            with _triton_mode():
                return fp8_utils.per_token_group_quant_fp8(
                    x,
                    group_size,
                    column_major_scales=column_major,
                    use_ue8m0=scale_ue8m0,
                )
    elif dtype == "int8":

        def cuda_impl():
            return int8_utils.per_token_group_quant_int8(x, group_size)

        def triton_impl():
            with _triton_mode():
                return int8_utils.per_token_group_quant_int8(x, group_size)
    else:
        raise ValueError("dtype must be 'fp8' or 'int8'")

    cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
    triton_ms = _time_cuda(triton_impl, warmup_iters, bench_iters)

    speedup = triton_ms / cuda_ms if cuda_ms else math.inf

    cfg_desc = (
        f"shape={shape} gs={group_size:<3} col_major={column_major:<5} "
        f"ue8m0={scale_ue8m0:<5} dtype={dtype}"
    )
    print(
        f"{cfg_desc:55} | CUDA {cuda_ms:7.3f} ms | Triton {triton_ms:7.3f} ms | "
        f"speed-up ×{speedup:5.2f}"
    )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup-iters", type=int, default=10)
    parser.add_argument("--bench-iters", type=int, default=100)
    parser.add_argument("--dtype", choices=["fp8", "int8", "both"], default="both")
    return parser.parse_args()


if __name__ == "__main__":
    if not current_platform.is_cuda():
        raise RuntimeError("CUDA device is required to run this benchmark.")

    args = parse_args()
    warmup_iters, bench_iters = args.warmup_iters, args.bench_iters

    shapes = [(32, 128), (64, 256), (16, 512)]
    group_sizes = [64, 128]

    dtypes = ["fp8", "int8"] if args.dtype == "both" else [args.dtype]

    header = (
        "Configuration".ljust(55)
        + " | "
        + "CUDA (ms)".center(12)
        + " | "
        + "Triton (ms)".center(13)
        + " | "
        + "Speed-up"
    )
    print(header)
    print("-" * len(header))

    for dtype in dtypes:
        for shape in shapes:
            for gs in group_sizes:
                if dtype == "fp8":
                    for col_major in (False, True):
                        for ue8m0 in (False, True):
                            _run_single(
                                shape,
                                gs,
                                dtype,
                                column_major=col_major,
                                scale_ue8m0=ue8m0,
                                warmup_iters=warmup_iters,
                                bench_iters=bench_iters,
                            )
                else:
                    # INT8 has no col-major / ue8m0 switches
                    _run_single(
                        shape,
                        gs,
                        dtype,
                        warmup_iters=warmup_iters,
                        bench_iters=bench_iters,
                    )
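A typical run, using the flags defined in parse_args above (the values shown are the defaults), might be:

python3 benchmark_per_token_group_quant.py --dtype both --warmup-iters 10 --bench-iters 100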
benchmarks/kernels/benchmark_per_token_quant_fp8.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Callable
from unittest.mock import patch

import pandas as pd
import torch

from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE


def with_triton_mode(fn):
    """Temporarily force the Triton fallback path"""

    def wrapped(*args, **kwargs):
        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
            return fn(*args, **kwargs)

    return wrapped


# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
    def inner(*args):
        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
        return fn(*args)

    return inner


def bench_compile(fn: Callable):
    # recompile for different shapes
    fwd = torch.compile(fn, fullgraph=True, dynamic=False)

    # First dim is explicitly dynamic to simulate vLLM usage
    return with_dyn_arg(fwd, 0, 0)


torch._dynamo.config.recompile_limit = 8888


def calculate_diff(
    batch_size: int,
    hidden_size: int,
    group_shape: GroupShape,
    dtype: torch.dtype,
):
    """Calculate the difference between Inductor and CUDA implementations."""
    device = torch.device("cuda")
    x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)

    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

    try:
        torch.testing.assert_close(
            cuda_out.to(torch.float32),
            torch_out.to(torch.float32),
            rtol=1e-3,
            atol=1e-5,
        )
        torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)
        torch.testing.assert_close(
            cuda_out.to(torch.float32),
            torch_eager_out.to(torch.float32),
            rtol=1e-3,
            atol=1e-5,
        )
        torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5)
        print("✅ All implementations match")
    except AssertionError as e:
        print("❌ Implementations differ")
        print(e)


configs = []


@default_vllm_config()
def benchmark_quantization(
    batch_size,
    hidden_size,
    provider,
    group_shape: GroupShape,
    col_major: bool,
    dtype: torch.dtype,
):
    device = torch.device("cuda")

    x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

    if provider == "torch":
        fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
    elif provider == "cuda":
        fn = lambda: quant_fp8.forward_cuda(x.clone())
    elif provider == "triton":
        if not group_shape.is_per_group():
            # Triton only supported for per-group
            return 0, 0, 0
        fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


# TODO(luka) extract to utils
def compute_geomean_speedups(
    df: pd.DataFrame,
    baseline_col: str,
    speedup_cols: list[str],
    groupby_cols: list[str] | None = None,
) -> pd.DataFrame:
    """
    Compute geometric mean speedups over a baseline column.

    Args:
        df: Input dataframe
        baseline_col: Column to use as baseline
        speedup_cols: Columns to compute speedups for
        groupby_cols: Columns to group by. If None, compute over entire df.

    Returns:
        pd.DataFrame with geometric mean speedups
    """
    from scipy.stats import gmean

    def geo_speedup(group: pd.DataFrame) -> pd.Series:
        ratios = {
            col: (group[baseline_col] / group[col]).values for col in speedup_cols
        }
        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})

    if groupby_cols is None:
        result = geo_speedup(df).to_frame().T
    else:
        result = (
            df.groupby(groupby_cols)
            .apply(geo_speedup, include_groups=False)
            .reset_index()
        )

    return result


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
    )
    parser.add_argument("-c", "--check", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
    )
    parser.add_argument(
        "--hidden-sizes",
        type=int,
        nargs="+",
        default=[896, 1024, 2048, 4096, 7168],
        help="Hidden sizes to benchmark",
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 16, 128, 512, 1024],
        help="Batch sizes to benchmark",
    )
    parser.add_argument(
        "--group-sizes",
        type=int,
        nargs="+",
        default=None,
        help="Group sizes for GroupShape(1,N) to benchmark. "
        "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
    )
    parser.add_argument(
        "--no-column-major",
        action="store_true",
        help="Disable column-major scales testing",
    )

    args = parser.parse_args()
    assert args

    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

    hidden_sizes = args.hidden_sizes
    batch_sizes = args.batch_sizes

    if args.group_sizes is not None:
        group_shapes = []
        for size in args.group_sizes:
            if size == 0:
                group_shapes.append(GroupShape.PER_TENSOR)
            elif size == -1:
                group_shapes.append(GroupShape.PER_TOKEN)
            else:
                group_shapes.append(GroupShape(1, size))
    else:
        group_shapes = [
            GroupShape.PER_TENSOR,
            GroupShape.PER_TOKEN,
            GroupShape(1, 64),
            GroupShape(1, 128),
        ]

    column_major_scales = [False] if args.no_column_major else [True, False]

    config_gen = itertools.product(
        group_shapes,
        column_major_scales,
        batch_sizes,
        hidden_sizes,
    )

    # filter out column-major scales for non-group, reverse order
    configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))

    print(f"Running {len(configs)} configurations:")
    print(f"  Hidden sizes: {hidden_sizes}")
    print(f"  Batch sizes: {batch_sizes}")
    print(f"  Group shapes: {[str(g) for g in group_shapes]}")
    print(f"  Column major scales: {column_major_scales}")
    print()

    if args.check:
        for group_shape in group_shapes:
            group_size = group_shape[1]
            print(f"{group_size=}")
            calculate_diff(
                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
            )

    benchmark = triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
            x_vals=configs,
            line_arg="provider",
            line_vals=["torch", "cuda", "triton"],
            line_names=["Torch (Compiled)", "CUDA", "Triton"],
            styles=[("blue", "-"), ("green", "-"), ("black", "-")],
            ylabel="us",
            plot_name="QuantFP8 performance",
            args={},
        )
    )(benchmark_quantization)

    df = benchmark.run(print_data=True, dtype=dtype, return_df=True)

    # Print geomean speedups
    geo_table_grouped = compute_geomean_speedups(
        df,
        baseline_col="Torch (Compiled)",
        speedup_cols=["CUDA", "Triton"],
        groupby_cols=["col_major", "group_shape"],
    )
    print("Speedup over Torch (Compiled)")
    print(geo_table_grouped.to_string(index=False))
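Judging from the argument parser above, the -c/--check flag runs the calculate_diff comparison before the sweep; an example invocation (batch and hidden sizes chosen here only for illustration) might be:

python3 benchmark_per_token_quant_fp8.py --check --dtype bfloat16 --batch-sizes 1 16 128 --hidden-sizes 4096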
benchmarks/kernels/benchmark_quant.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed


@torch.inference_mode()
def main(
    num_tokens: int,
    hidden_size: int,
    static_scale: bool,
    quant_dtype: torch.dtype,
    dtype: torch.dtype,
    seed: int = 0,
    do_profile: bool = False,
    num_warmup_iters: int = 5,
    num_iters: int = 100,
) -> None:
    set_random_seed(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if quant_dtype == torch.int8:
                ops.scaled_int8_quant(x, scale)
            else:
                ops.scaled_fp8_quant(x, scale)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported dtype: {dt}")

    parser = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel."
    )
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--static-scale", action="store_true")
    parser.add_argument("--quant-dtype", type=str, choices=["fp8", "int8"], default="int8")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations. "
        "If --profile is set, this number is ignored",
    )
    args = parser.parse_args()
    print(args)

    main(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        static_scale=args.static_scale,
        quant_dtype=to_torch_dtype(args.quant_dtype),
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
    )
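An example invocation based on the flags above, exercising the fp8 path with the default shape:

python3 benchmark_quant.py --quant-dtype fp8 --num-tokens 4096 --hidden-size 8192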
benchmarks/kernels/benchmark_reshape_and_cache.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    create_kv_caches_with_random,
    set_random_seed,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    num_iters: int,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Return latency (seconds) for given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    set_random_seed(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # to free unused memory
    del key_caches, value_caches

    # compute per-kernel scaling factors for fp8 conversion (if used).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    function_under_test = lambda: ops.reshape_and_cache(
        key,  # noqa: F821
        value,  # noqa: F821
        key_cache,  # noqa: F821
        value_cache,  # noqa: F821
        slot_mapping,  # noqa: F821
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if benchmark_mode == "cudagraph":
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            function_under_test()
        torch.cuda.synchronize()
        function_under_test = lambda: g.replay()

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
    for exp in range(1, 17):
        n_tok = 2**exp
        lat = run_benchmark(
            num_tokens=n_tok,
            num_heads=args.num_heads,
            head_size=args.head_size,
            block_size=args.block_size,
            num_blocks=args.num_blocks,
            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
            kv_cache_dtype=args.kv_cache_dtype,
            num_iters=args.iters,
            benchmark_mode=args.mode,
            device="cuda",
        )
        rows.append([n_tok, lat * 1e6])  # convert to microseconds

    print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
    print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 128)
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )
    parser.add_argument("--iters", type=int, default=200)
    parser.add_argument(
        "--mode",
        type=str,
        choices=["cudagraph", "no_graph"],
        default="cudagraph",
    )
    args = parser.parse_args()

    main(args)
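The slot mapping built above assigns each token one flat slot in the paged KV cache; a flat slot index decomposes into a block index and an in-block offset. The following lines are an illustrative sketch only, not part of the script.

block_size = 16   # matches the script's default --block-size
slot = 1234       # one entry of slot_mapping
block_idx, block_offset = slot // block_size, slot % block_size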
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    create_kv_caches_with_random_flash,
    set_random_seed,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    num_iters: int,
    implementation: str,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Return latency (seconds) for given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
    if implementation not in ("cuda", "triton"):
        raise ValueError(
            f"Unsupported implementation: {implementation}. "
            "Only 'cuda' and 'triton' are supported."
        )
    if implementation == "triton" and kv_cache_layout == "HND":
        return float("nan")  # Triton does not support HND layout yet.

    set_random_seed(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random_flash(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # to free unused memory
    del key_caches, value_caches

    # compute per-kernel scaling factors for fp8 conversion (if used).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    if implementation == "cuda":
        function_under_test = lambda: ops.reshape_and_cache_flash(
            key,  # noqa: F821
            value,  # noqa: F821
            key_cache,  # noqa: F821
            value_cache,  # noqa: F821
            slot_mapping,  # noqa: F821
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    else:
        function_under_test = lambda: triton_reshape_and_cache_flash(
            key,  # noqa: F821
            value,  # noqa: F821
            key_cache,  # noqa: F821
            value_cache,  # noqa: F821
            slot_mapping,  # noqa: F821
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    if benchmark_mode == "cudagraph":
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            function_under_test()
        torch.cuda.synchronize()
        function_under_test = lambda: g.replay()

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
    for layout in ["NHD", "HND"]:
        for exp in range(1, 17):
            n_tok = 2**exp
            lat = run_benchmark(
                num_tokens=n_tok,
                num_heads=args.num_heads,
                head_size=args.head_size,
                block_size=args.block_size,
                num_blocks=args.num_blocks,
                dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
                kv_cache_dtype=args.kv_cache_dtype,
                kv_cache_layout=layout,
                num_iters=args.iters,
                implementation=args.implementation,
                benchmark_mode=args.mode,
                device="cuda",
            )
            rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])

    print(
        f"Benchmark results for implementation {args.implementation}"
        f" (measuring with {args.mode}):"
    )
    print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 512)
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument(
        "--implementation",
        type=str,
        choices=["cuda", "triton"],
        default="cuda",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["cudagraph", "no_graph"],
        default="cudagraph",
    )
    args = parser.parse_args()

    main(args)
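For orientation, the two cache layouts swept in main() differ only in where the head dimension sits; the mapping below is an illustrative sketch of the shapes produced by create_kv_caches_with_random_flash, not code executed by the benchmark.

# Illustrative only: per-layout key/value cache shapes.
CACHE_SHAPES = {
    "NHD": ("num_blocks", "block_size", "num_heads", "head_size"),
    "HND": ("num_blocks", "num_heads", "block_size", "head_size"),
}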
benchmarks/kernels/benchmark_rmsnorm.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools

import torch
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn

from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton


class HuggingFaceRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual


def rmsnorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor | None = None,
    eps: float = 1e-6,
):
    naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
    naive_norm.weight = nn.Parameter(weight)
    naive_norm = naive_norm.to(x.device)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    output = naive_norm(x, residual)

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output


def rmsnorm_flashinfer(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor | None = None,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    if residual is not None:
        fused_add_rmsnorm(x, residual, weight, eps)
        output = (x, residual)
    else:
        output = rmsnorm(x, weight, eps)

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output


def rmsnorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor | None = None,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    if residual is not None:
        vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
        output = (x, residual)
    else:
        out = torch.empty_like(x)
        vllm_ops.rms_norm(out, x, weight, eps)
        output = out

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output


def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x) if use_residual else None

    output_naive = rmsnorm_naive(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_flashinfer = rmsnorm_flashinfer(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_vllm = rmsnorm_vllm(
        x.clone(), weight, residual.clone() if residual is not None else None
    )

    if use_residual:
        output_naive = output_naive[0]
        output_flashinfer = output_flashinfer[0]
        output_vllm = output_vllm[0]

    print(f"Naive output={output_naive}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"vLLM output={output_vllm}")

    if torch.allclose(
        output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
    ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))


def get_benchmark(use_residual):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["head_num", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["huggingface", "flashinfer", "vllm"],
            line_names=["HuggingFace", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
            args={},
        )
    )
    def benchmark(head_num, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        hidden_size = head_num * 128  # assuming head_dim = 128

        x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
        weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
        residual = torch.randn_like(x) if use_residual else None

        quantiles = [0.5, 0.2, 0.8]

        if provider == "huggingface":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_naive(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        elif provider == "flashinfer":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_flashinfer(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_vllm(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length",
    )
    parser.add_argument(
        "--hidden-size",
        type=int,
        default=4096,
        help="Hidden size (2nd dimension) of the sequence",
    )
    parser.add_argument(
        "--use-residual", action="store_true", help="Whether to use residual connection"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/rmsnorm/",
        help="Path to save rmsnorm benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_size=args.hidden_size,
        use_residual=args.use_residual,
    )

    # Get the benchmark function with proper use_residual setting
    benchmark = get_benchmark(args.use_residual)
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
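For a quick correctness-only check without the full performance sweep, the comparison can also be driven directly; the sketch below assumes a CUDA device, FlashInfer installed, and that the file above is importable as benchmark_rmsnorm.

from benchmark_rmsnorm import calculate_diff  # assumed import path

calculate_diff(batch_size=4, seq_len=128, hidden_size=4096, use_residual=True)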
benchmarks/kernels/benchmark_rope.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools

import torch

from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

batch_size_range = [2**i for i in range(0, 8, 2)]
seq_len_range = [2**i for i in range(6, 10, 1)]
num_heads_range = [32, 48]
configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))


def get_benchmark(head_size, rotary_dim, is_neox_style, device):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len", "num_heads"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["torch", "flashinfer", "vllm"],
            line_names=["PyTorch", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
            args={},
        )
    )
    @default_vllm_config()
    def benchmark(batch_size, seq_len, num_heads, provider):
        dtype = torch.bfloat16
        max_position = 8192
        rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
        rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
        rope = rope.to(dtype=dtype, device=device)
        cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)

        positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
        query = torch.randn(
            (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
        )
        key = torch.randn_like(query)

        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rope.forward_native(positions, query.clone(), key.clone()),
                quantiles=quantiles,
            )
        elif provider == "flashinfer":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: torch.ops.vllm.flashinfer_rotary_embedding(
                    positions,
                    query.clone(),
                    key.clone(),
                    head_size,
                    cos_sin_cache,
                    is_neox_style,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--is-neox-style", type=bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument(
        "--dtype", type=str, choices=["bfloat16", "float"], default="float"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
    )
    parser.add_argument("--save-path", type=str, default="./configs/rope/")
    args = parser.parse_args()

    # Get the benchmark function
    benchmark = get_benchmark(
        args.head_size, args.rotary_dim, args.is_neox_style, args.device
    )
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
benchmarks/kernels/benchmark_shapes.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

WEIGHT_SHAPES = {
    "ideal": [[4 * 256 * 32, 256 * 32]],
    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],
        [2048, 4096],
        [4096, 14336],
        [7168, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP4": [
        [4096, 1536],
        [1024, 4096],
        [4096, 7168],
        [3584, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP1": [
        [4096, 12288],
        [4096, 4096],
        [4096, 22016],
        [11008, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP2": [
        [4096, 6144],
        [2048, 4096],
        [4096, 11008],
        [5504, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP4": [
        [4096, 3072],
        [1024, 4096],
        [4096, 5504],
        [2752, 4096],
    ],
    "meta-llama/Llama-2-13b-hf/TP1": [
        [5120, 15360],
        [5120, 5120],
        [5120, 27648],
        [13824, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP2": [
        [5120, 7680],
        [2560, 5120],
        [5120, 13824],
        [6912, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP4": [
        [5120, 3840],
        [1280, 5120],
        [5120, 6912],
        [3456, 5120],
    ],
    "meta-llama/Llama-2-70b-hf/TP1": [
        [8192, 10240],
        [8192, 8192],
        [8192, 57344],
        [28672, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP2": [
        [8192, 5120],
        [4096, 8192],
        [8192, 28672],
        [14336, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP4": [
        [8192, 2560],
        [2048, 8192],
        [8192, 14336],
        [7168, 8192],
    ],
}

WEIGHT_SHAPES_MOE = {
    "mistralai/Mixtral-8x7B-Instruct-v0.1": [
        [8, 2, 4096, 28672],
        [8, 2, 14336, 4096],
    ],
    "deepseek-ai/DeepSeek-V2-Lite": [
        [64, 6, 2048, 1408],
    ],
    "ibm-granite/granite-3.0-1b-a400m": [
        [32, 8, 1024, 1024],
    ],
    "ibm-granite/granite-3.0-3b-a800m": [
        [40, 8, 1024, 1536],
    ],
}
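Other kernel benchmarks in this directory look these tables up by model name and tensor-parallel degree; each inner pair in WEIGHT_SHAPES appears to be the input/output dimensions of one linear layer. A minimal consumer sketch (the import path is an assumption based on the file name):

from benchmark_shapes import WEIGHT_SHAPES  # assumed import path

for dim_in, dim_out in WEIGHT_SHAPES["meta-llama/Llama-2-7b-hf/TP1"]:
    print(f"GEMM weight shape: {dim_in} x {dim_out}")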
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Comprehensive 3-way SiLU Benchmark Suite

This benchmark compares three SiLU implementations:
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
2. Triton Kernel - Triton-based implementation

The suite generates detailed performance comparisons including:
- Memory bandwidth utilization
- Speedup ratios (baseline vs optimized implementations)
- Performance across different expert configurations and token distributions
"""

from collections.abc import Callable

import matplotlib.pyplot as plt
import numpy as np
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    persistent_masked_m_silu_mul_quant,
)
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import set_random_seed


@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
    # Pointers ------------------------------------------------------------
    input_ptr,  # 16-bit activations (E, T, 2*H)
    y_q_ptr,  # fp8 quantized activations (E, T, H)
    y_s_ptr,  # 16-bit scales (E, T, G)
    counts_ptr,  # int32 num tokens per expert (E)
    # Sizes ---------------------------------------------------------------
    H: tl.constexpr,  # hidden dimension (per output)
    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
    # Strides for input (elements) ---------------------------------------
    stride_i_e,
    stride_i_t,
    stride_i_h,
    # Strides for y_q (elements) -----------------------------------------
    stride_yq_e,
    stride_yq_t,
    stride_yq_h,
    # Strides for y_s (elements) -----------------------------------------
    stride_ys_e,
    stride_ys_t,
    stride_ys_g,
    # Stride for counts (elements)
    stride_counts_e,
    # Numeric params ------------------------------------------------------
    eps: tl.constexpr,
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,
    use_ue8m0: tl.constexpr,
    # Meta ---------------------------------------------------------------
    BLOCK: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    G = H // GROUP_SIZE

    # map program id -> (e, g)
    pid = tl.program_id(0)
    e = pid // G
    g = pid % G

    e = e.to(tl.int64)
    g = g.to(tl.int64)

    # number of valid tokens for this expert
    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

    cols = tl.arange(0, BLOCK).to(tl.int64)
    mask = cols < BLOCK

    base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
    base_gate_offset = base_input_offset + cols * stride_i_h
    base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
    base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
    base_ys_offset = e * stride_ys_e + g * stride_ys_g

    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
        gate = tl.load(
            input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
        ).to(tl.float32)
        up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0)

        gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
        y = gate * up

        y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
        if use_ue8m0:
            y_s = tl.exp2(tl.ceil(tl.log2(y_s)))

        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
        tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)


def silu_mul_fp8_quant_deep_gemm_triton(
    y: torch.Tensor,  # (E, T, 2*H)
    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
    num_parallel_tokens,
    group_size: int = 128,
    eps: float = 1e-10,
    expert_offsets: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales

    y has shape (E, T, 2*H). The first half of the last dimension is
    silu-activated, multiplied by the second half, then quantized into FP8.

    Returns `(y_q, y_s)` where
    * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
    """
    assert y.ndim == 3, "y must be (E, T, 2*H)"
    E, T, H2 = y.shape
    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
    H = H2 // 2
    G = (H + group_size - 1) // group_size
    assert H % group_size == 0, "H must be divisible by group_size"
    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
        "tokens_per_expert must be shape (E,)"
    )
    tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)

    # allocate outputs
    fp8_dtype = torch.float8_e4m3fn
    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

    # strides (elements)
    stride_i_e, stride_i_t, stride_i_h = y.stride()
    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()

    # desired scale strides (elements): (T*G, 1, T)
    stride_ys_e = T * G
    stride_ys_t = 1
    stride_ys_g = T
    y_s = torch.empty_strided(
        (E, T, G),
        (stride_ys_e, stride_ys_t, stride_ys_g),
        dtype=torch.float32,
        device=y.device,
    )

    stride_cnt_e = tokens_per_expert.stride()[0]

    # Static grid over experts and H-groups.
    # A loop inside the kernel handles the token dim
    grid = (E * G,)

    f_info = torch.finfo(fp8_dtype)
    fp8_max = f_info.max
    fp8_min = f_info.min

    _silu_mul_fp8_quant_deep_gemm[grid](
        y,
        y_q,
        y_s,
        tokens_per_expert,
        H,
        group_size,
        stride_i_e,
        stride_i_t,
        stride_i_h,
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,
        stride_cnt_e,
        eps,
        fp8_min,
        fp8_max,
        is_deep_gemm_e8m0_used(),
        BLOCK=group_size,
        NUM_STAGES=4,
        num_warps=1,
    )

    return y_q, y_s


# Parse generation strategies
strategies = ["random_imbalanced", "uniform", "max_t"]


def benchmark(
    kernel: Callable,
    E: int,
    T: int,
    H: int,
    total_tokens: int,
    num_parallel_tokens: int = 64,
    G: int = 128,
    runs: int = 200,
    num_warmups: int = 20,
    gen_strategy: str = "default",
    iterations_per_run: int = 20,
):
    def generate_data(seed_offset=0):
        """Generate input data with given seed offset"""
        set_random_seed(42 + seed_offset)
        y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()

        if gen_strategy == "random_imbalanced":

            def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
                mean = total_tokens // n_e
                min_max = mean // ratio
                e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
                e[0] = min_max
                r = torch.rand(size=(E - 1,))
                r /= r.sum()
                r *= total_tokens - min_max
                r = r.round().long()
                e[1:] = r.to(device=device)
                return e

            tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
        elif gen_strategy == "uniform":
            r = torch.rand(size=(E,))
            r /= r.sum()
            r *= total_tokens
            r = r.round().long()
            tokens_per_expert = r
        elif gen_strategy == "max_t":
            tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
            tokens_per_expert.fill_(total_tokens / E)
        elif gen_strategy == "first_t":
            tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda")
            tokens_per_expert[0] = min(T, total_tokens)
        else:
            raise ValueError(f"Unknown generation strategy: {gen_strategy}")
        return y, tokens_per_expert

    dataset_count = 4
    # Pre-generate different input matrices for each iteration to avoid cache effects
    data_sets = [generate_data(i) for i in range(dataset_count)]

    # Warmup
    y, tokens_per_expert = data_sets[0]
    for _ in range(num_warmups):
        kernel(
            y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G
        )
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    # Benchmark
    latencies: list[float] = []
    for _ in range(runs):
        torch.cuda.synchronize()

        start_event.record()
        for i in range(iterations_per_run):
            y, tokens_per_expert = data_sets[i % dataset_count]
            kernel(
                y,
                tokens_per_expert,
                num_parallel_tokens=num_parallel_tokens,
                group_size=G,
            )
        end_event.record()
        end_event.synchronize()

        total_time_ms = start_event.elapsed_time(end_event)
        per_iter_time_ms = total_time_ms / iterations_per_run
        latencies.append(per_iter_time_ms)

    # Use median instead of average for better outlier handling
    median_time_ms = np.median(latencies)
    median_time_s = median_time_ms / 1000

    # Calculate actual work done (using first dataset for consistency)
    _, tokens_per_expert = data_sets[0]
    actual_tokens = tokens_per_expert.sum().item()
    actual_elements = actual_tokens * H

    # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
    ops_per_element = 8
    total_ops = actual_elements * ops_per_element
    gflops = total_ops / median_time_s / 1e9

    # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
    input_bytes = actual_tokens * 2 * H * 2  # 2*H bfloat16 inputs
    output_bytes = actual_tokens * H * 1  # H fp8 outputs
    scale_bytes = actual_tokens * (H // G) * 4  # scales in float32
    total_bytes = input_bytes + output_bytes + scale_bytes
    memory_bw = total_bytes / median_time_s / 1e9

    HOPPER_BANDWIDTH_TBPS = 3.35
    return (
        median_time_ms,
        gflops,
        memory_bw,
        (memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100,
    )


def create_comparison_plot(
    ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
):
    fig, ax = plt.subplots(1, 1, figsize=(18, 6))

    # Configure x-axis positions
    x = np.arange(len(config_labels))
    width = 0.25

    # Execution Time plot (lower is better)
    ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
    ax.bar(
        x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
    )

    # Add speedup labels over each bar trio
    for i in range(len(x)):
        triton_v2_speedup = ratios[i][1]  # triton/v2
        max_height = max(silu_v2_times[i], triton_times[i])

        # Triton/V2 speedup
        ax.text(
            x[i] + width / 2,
            max_height + max_height * 0.02,
            f"{triton_v2_speedup:.2f}x",
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=8,
        )

    ax.set_xlabel("Configuration")
    ax.set_ylabel("% Utilization")
    ax.set_title(
        f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
    )
    ax.set_xticks(x)
    ax.set_xticklabels(config_labels, rotation=45, ha="right")
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig, ax


def create_combined_plot(all_results):
    num_strategies = len(all_results)
    fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))

    if num_strategies == 1:
        axes = [axes]

    for idx, (
        strategy_name,
        all_ratios,
        all_silu_v2_results,
        all_triton_results,
        config_labels,
        config_x_axis,
    ) in enumerate(all_results):
        ax = axes[idx]

        # Flatten the nested results to get bandwidth percentages for plotting
        silu_v2_bandwidths = []
        triton_bandwidths = []
        flat_ratios = []

        for config_results in all_silu_v2_results:
            for result in config_results:
                silu_v2_bandwidths.append(result[3])  # bandwidth percentage

        for config_results in all_triton_results:
            for result in config_results:
                triton_bandwidths.append(result[3])  # bandwidth percentage

        for config_ratios in all_ratios:
            for ratio in config_ratios:
                flat_ratios.append(ratio)

        # Configure x-axis positions
        x = np.arange(len(config_labels))
        width = 0.25

        # Bandwidth utilization plot (higher is better)
        ax.bar(
            x,
            silu_v2_bandwidths,
            width,
            label="SiLU V2 (CUDA)",
            alpha=0.8,
            color="blue",
        )
        ax.bar(
            x + width,
            triton_bandwidths,
            width,
            label="Triton Kernel",
            alpha=0.8,
            color="green",
        )

        # Add speedup labels over each bar trio
        for i in range(len(x)):
            triton_v2_speedup = flat_ratios[i]  # triton/v2
            max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])

            # Triton/V2 speedup
            ax.text(
                x[i] + width / 2,
                max_height + max_height * 0.02,
                f"{triton_v2_speedup:.2f}x",
                ha="center",
                va="bottom",
                fontweight="bold",
                fontsize=8,
            )

        ax.set_xlabel("Configuration")
        ax.set_ylabel("% Utilization")
        ax.set_title(
            f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
        )
        ax.set_xticks(x)
        ax.set_xticklabels(config_labels, rotation=45, ha="right")
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    filename = "silu_benchmark_combined_3way.png"
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()

    return filename


outer_dim = 7168
configs = [
    # DeepSeekV3 Configs
    # (1, 56, 7168),
    (8, 1024, 7168),
    # (32, 56, 7168),
    # DeepSeekV3 Configs
    (32, 1024, 7168),
    # DeepSeekV3 Configs
    (256, 1024, 7168),
]

runs = 100
num_warmups = 20

strategy_descriptions = {
    "uniform": "Uniform Random",
    "random_imbalanced": "Imbalanced Random",
    "max_t": "Even Assignment",
    "first_t": "experts[0] = T, experts[1:] = 0",
}

print(f"GPU: {torch.cuda.get_device_name()}")
print(f"Testing strategies: {', '.join(strategies)}")
print(f"Configurations: {len(configs)} configs")

all_results = []

# Run benchmarks for each strategy
for id, strategy in enumerate(strategies):
    print(f"\n{'=' * 60}")
    print(f"Testing strategy: {strategy_descriptions[strategy]}")
    print(f"{'=' * 60}")

    # Collect benchmark data for all three algorithms
    config_labels = []
    config_x_axis = []
    all_silu_v2_results = []
    all_triton_results = []
    all_ratios = []

    for E, T, H in configs:
        total_tokens_config = []
        for i in [8, 16, 32, 64, 128, 256, 512]:
            if i <= T:
                total_tokens_config.append(i * E)
        config_x_axis.append(total_tokens_config)

        silu_v2_results = []
        triton_results = []
        ratios = []

        for total_tokens in total_tokens_config:
            config_label = f"E={E},T={T},H={H},TT={total_tokens}"
            config_labels.append(config_label)

            # SiLU V2 (CUDA kernel) results
            time_ms_silu_v2, gflops, gbps, perc = benchmark(
                persistent_masked_m_silu_mul_quant,
                E,
                T,
                H,
                total_tokens,
                runs=runs,
                num_warmups=num_warmups,
                gen_strategy=strategy,
            )
            silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))

            # Triton kernel results
            time_ms_triton, gflops, gbps, perc = benchmark(
                silu_mul_fp8_quant_deep_gemm_triton,
                E,
                T,
                H,
                total_tokens,
                runs=runs,
                num_warmups=num_warmups,
                gen_strategy=strategy,
            )
            triton_results.append((time_ms_triton, gflops, gbps, perc))

            # Calculate speedup ratios (triton baseline / implementation)
            triton_v2_ratio = time_ms_triton / time_ms_silu_v2
            ratios.append(triton_v2_ratio)

            print(
                f"Completed: {config_label}:"
                f" V2: {time_ms_silu_v2:.3f}ms,"
                f" Triton: {time_ms_triton:.3f}ms"
            )

        all_silu_v2_results.append(silu_v2_results)
        all_triton_results.append(triton_results)
        all_ratios.append(ratios)

    # Store results for combined plotting
    all_results.append(
        (
            strategy_descriptions[strategy],
            all_ratios,
            all_silu_v2_results,
            all_triton_results,
            config_labels,
            config_x_axis,
        )
    )

    # Print summary table for this strategy
    print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
    print(f"{'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}")
    print("-" * 90)

    for i, (E, T, H) in enumerate(configs):
        # Get the first result for each config (simplifying for summary)
        v2_time = silu_v2_results[i][0]
        triton_time = triton_results[i][0]
        triton_v2_speedup = triton_time / v2_time
        config_label = f"E={E:3d},T={T:4d},H={H:4d}"
        print(
            f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f}"
            f" {triton_v2_speedup:8.2f}x"
        )


def create_total_tokens_plot(all_results):
    num_strategies = len(all_results)
    num_configs = len(configs)

    fig, axs = plt.subplots(
        num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
    )

    # Add main title to the entire figure
    fig.suptitle(
        "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)",
        fontsize=18,
        fontweight="bold",
        y=0.98,
    )

    # Handle single strategy case
    if num_strategies == 1:
        axs = axs.reshape(1, -1)

    # Handle single config case
    if num_configs == 1:
        axs = axs.reshape(-1, 2)

    for strategy_idx, result in enumerate(all_results):
        (
            strategy_name,
            all_ratios,
            all_silu_v2_results,
            all_triton_results,
            config_labels,
            config_x_axis,
        ) = result

        for config_idx in range(num_configs):
            # Speedup plot (left column)
            ax_speedup = axs[strategy_idx, config_idx * 2]
            # Bandwidth plot (right column)
            ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1]

            E, T, H = configs[config_idx]
            ratios = all_ratios[config_idx]
            total_tokens_values = config_x_axis[config_idx]

            # Extract speedup ratios
            triton_v2_ratios = [ratio for ratio in ratios]

            # Extract bandwidth percentages for all implementations
            v2_bandwidth_percentages = [
                result[3] for result in all_silu_v2_results[config_idx]
            ]
            triton_bandwidth_percentages = [
                result[3] for result in all_triton_results[config_idx]
            ]

            # Plot speedup ratios vs total tokens (left plot)
            ax_speedup.plot(
                total_tokens_values,
                triton_v2_ratios,
                "go-",
                linewidth=3,
                markersize=8,
                label="Triton/V2 Speedup",
            )
            ax_speedup.set_title(
                f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
                fontsize=12,
                fontweight="bold",
            )
            ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
            ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
            ax_speedup.legend(prop={"weight": "bold"})
            ax_speedup.grid(True, alpha=0.3)

            # Plot bandwidth utilization (right plot)
            ax_bandwidth.plot(
                total_tokens_values,
                v2_bandwidth_percentages,
                "o-",
                linewidth=3,
                markersize=8,
                label="SiLU V2",
                color="blue",
            )
            ax_bandwidth.plot(
                total_tokens_values,
                triton_bandwidth_percentages,
                "o-",
                linewidth=3,
                markersize=8,
                label="Triton",
                color="green",
            )
            ax_bandwidth.set_title(
                f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
                fontsize=12,
                fontweight="bold",
            )
            ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
            ax_bandwidth.set_ylabel("% of Peak Bandwidth", fontweight="bold", fontsize=11)
            ax_bandwidth.legend(prop={"weight": "bold"})
            ax_bandwidth.grid(True, alpha=0.3)

            # Format x-axis labels for both plots
            for ax in [ax_speedup, ax_bandwidth]:
                ax.set_xticks(total_tokens_values)
                ax.set_xticklabels(
                    [
                        f"{tt // 1000}K" if tt >= 1000 else str(tt)
                        for tt in total_tokens_values
                    ],
                    fontweight="bold",
                )
                # Make tick labels bold
                for label in ax.get_xticklabels() + ax.get_yticklabels():
                    label.set_fontweight("bold")

            # Add value labels on Triton/V2 speedup points
            for x, y in zip(total_tokens_values, triton_v2_ratios):
                ax_speedup.annotate(
                    f"{y:.2f}x",
                    (x, y),
                    textcoords="offset points",
                    xytext=(0, -15),
                    ha="center",
                    fontsize=9,
                    fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3),
                )

    plt.tight_layout()
    plt.subplots_adjust(top=0.93)  # Make room for main title
    filename = "silu_benchmark_total_tokens_3way.png"
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()

    return filename


# Create comprehensive 3-way comparison plots
combined_plot_filename = create_combined_plot(all_results)
total_tokens_plot_filename = create_total_tokens_plot(all_results)

print(f"\n{'=' * 80}")
print("3-Way Benchmark Suite Complete!")
print(f"Generated combined comparison plot: {combined_plot_filename}")
print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
print("Compared: SiLU V2 (CUDA), and Triton implementations")
print(f"{'=' * 80}")
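The per-group quantization step inside the Triton kernel above reduces to the following PyTorch-only reference for a single group of activations; this is an illustrative sketch (it omits the optional UE8M0 power-of-two rounding of the scale) and is not used by the script.

import torch

def quant_group(y: torch.Tensor, eps: float = 1e-10):
    # y: one contiguous group of silu(gate) * up activations
    fp8 = torch.finfo(torch.float8_e4m3fn)
    y_s = torch.clamp(y.abs().max(), min=eps) / fp8.max   # per-group scale
    y_q = torch.clamp(y / y_s, fp8.min, fp8.max).to(torch.float8_e4m3fn)
    return y_q, y_s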
benchmarks/kernels/benchmark_trtllm_decode_attention.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import csv
import os
from datetime import datetime

import flashinfer
import torch

from vllm.utils.math_utils import round_up

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_decode(
    dtype: torch.dtype,
    quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
    batch_size: int,
    max_seq_len: int,
    num_heads: tuple[int, int] = (64, 8),
    head_size: int = 128,
    kv_layout: str = "HND",
    block_size: int = 16,
    warmup: int = 10,
    trials: int = 20,
):
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / block_size)

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    # Always using 1.0 scale to reflect the real perf in benchmarking
    q_scale = 1.0
    ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
    if q_quant_dtype == FP8_DTYPE:
        query, _ = to_float8(ref_query)
    else:
        query = ref_query

    kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_seq_len

    seq_lens = kv_lens
    max_seq_len = torch.max(seq_lens).item()

    # Always using 1.0 scale to reflect the real perf in benchmarking
    k_scale = v_scale = 1.0
    ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, _ = to_float8(ref_kv_cache)
    else:
        kv_cache = ref_kv_cache

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(batch_size):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=True,
    )
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_size,
        block_size,
        "NONE",
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.Event(enable_timing=True)
        end = torch.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()
        for i in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    o_scale = 1.0
    o_sf_scale = None
    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
    if o_quant_dtype == FP4_DTYPE:
        o_sf_scale = 500.0
        output_trtllm = flashinfer.utils.FP4Tensor(
            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
            torch.empty(
                (
                    round_up(query.shape[0], 128),
                    round_up(query.shape[1] * query.shape[2] // 16, 4),
                ),
                dtype=torch.float8_e4m3fn,
            ),
        )
    else:
        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    def baseline_decode():
        return wrapper.run(
            ref_query,
            ref_kv_cache,
            k_scale=k_scale,
            v_scale=v_scale,
            out=output_baseline,
        )

    def trtllm_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens,
            max_seq_len=max_seq_len,
            bmm1_scale=q_scale * k_scale * sm_scale,
            bmm2_scale=v_scale / o_scale,
            o_sf_scale=o_sf_scale,
            out=output_trtllm,
        )

    baseline_mean, baseline_std = time_fn(baseline_decode)
    trtllm_mean, trtllm_std = time_fn(trtllm_decode)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (None, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]

    for quant_dtype in quant_dtypes:
        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
        q_quant_dtype = q_quant_dtype or dtype
        kv_quant_dtype = kv_quant_dtype or dtype
        o_quant_dtype = o_quant_dtype or dtype

        print(
            f"Running benchmark for q_dtype = {q_quant_dtype}, "
            f"kv_cache_dtype: {kv_quant_dtype}, "
            f"output_dtype: {o_quant_dtype}"
        )
        print(
            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
            "baseline_std\tspeedup_percent"
        )
        for max_seq_len in max_seq_lens:
            for bs in batch_sizes:
                result = benchmark_decode(
                    dtype=dtype,
                    quant_dtypes=quant_dtype,
                    batch_size=bs,
                    max_seq_len=max_seq_len,
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
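For orientation, to_float8() above applies a symmetric scale chosen so the largest magnitude lands at roughly 10% of the FP8 range, and it returns the matching dequantization factor. A small standalone sketch (the import path is an assumption based on the file name, and a CUDA device is assumed):

import torch
from benchmark_trtllm_decode_attention import to_float8  # assumed import path

x = torch.randn(4, 8, device="cuda")
x_fp8, descale = to_float8(x)
x_approx = x_fp8.to(torch.float32) * descale  # approximate reconstruction of x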
benchmarks/kernels/benchmark_trtllm_prefill_attention.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
csv
import
os
from
datetime
import
datetime
import
flashinfer
import
torch
from
vllm.utils.math_utils
import
round_up
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FP8_DTYPE
=
torch
.
float8_e4m3fn
FP4_DTYPE
=
torch
.
uint8
def
to_float8
(
x
,
dtype
=
torch
.
float8_e4m3fn
):
finfo
=
torch
.
finfo
(
dtype
)
min_val
,
max_val
=
x
.
aminmax
()
amax
=
torch
.
maximum
(
min_val
.
abs
(),
max_val
.
abs
()).
clamp
(
min
=
1e-12
)
scale
=
finfo
.
max
/
amax
*
0.1
x_scl_sat
=
(
x
*
scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
return
x_scl_sat
.
to
(
dtype
),
scale
.
float
().
reciprocal
()
@
torch
.
no_grad
()
def
benchmark_prefill
(
dtype
:
torch
.
dtype
,
quant_dtypes
:
tuple
[
torch
.
dtype
|
None
,
torch
.
dtype
|
None
,
torch
.
dtype
|
None
],
batch_size
:
int
,
max_seq_len
:
int
,
num_heads
:
tuple
[
int
,
int
]
=
(
64
,
8
),
head_size
:
int
=
128
,
kv_layout
:
str
=
"HND"
,
block_size
:
int
=
16
,
warmup
:
int
=
10
,
trials
:
int
=
20
,
):
torch
.
set_default_device
(
"cuda"
)
torch
.
manual_seed
(
0
)
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
=
q_quant_dtype
or
dtype
kv_quant_dtype
=
kv_quant_dtype
or
dtype
o_quant_dtype
=
o_quant_dtype
or
dtype
max_q_len
=
max_kv_len
=
max_seq_len
num_qo_heads
,
num_kv_heads
=
num_heads
assert
num_qo_heads
%
num_kv_heads
==
0
sm_scale
=
float
(
1.0
/
(
head_size
**
0.5
))
# large number to reduce kv_cache reuse
NUM_BLOCKS
=
int
(
256000
/
block_size
)
kv_cache_shape
=
None
if
kv_layout
==
"NHD"
:
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
block_size
,
num_kv_heads
,
head_size
)
elif
kv_layout
==
"HND"
:
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
num_kv_heads
,
block_size
,
head_size
)
else
:
raise
ValueError
(
f
"Invalid kv_layout:
{
kv_layout
}
"
)
q_lens
=
torch
.
randint
(
1
,
max_q_len
,
(
batch_size
,),
dtype
=
torch
.
int32
)
q_lens
[
-
1
]
=
max_q_len
q_indptr
=
torch
.
cat
(
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
torch
.
cumsum
(
q_lens
,
dim
=
0
,
dtype
=
torch
.
int32
),
]
)
# Always using 1.0 scale to reflect the real perf in benchmarking
q_scale
=
1.0
ref_query
=
torch
.
randn
(
torch
.
sum
(
q_lens
).
item
(),
num_qo_heads
,
head_size
,
dtype
=
dtype
)
if
q_quant_dtype
==
FP8_DTYPE
:
query
,
_
=
to_float8
(
ref_query
)
else
:
query
=
ref_query
kv_lens
=
torch
.
randint
(
0
,
max_kv_len
,
(
batch_size
,),
dtype
=
torch
.
int32
)
kv_lens
[
-
1
]
=
max_kv_len
seq_lens
=
kv_lens
+
q_lens
max_seq_len
=
torch
.
max
(
seq_lens
).
item
()
# Always using 1.0 scale to reflect the real perf in benchmarking
k_scale
=
v_scale
=
1.0
ref_kv_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
if
kv_quant_dtype
==
FP8_DTYPE
:
kv_cache
,
_
=
to_float8
(
ref_kv_cache
)
else
:
kv_cache
=
ref_kv_cache
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
(
batch_size
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
batch_size
):
seq_len
=
seq_lens
[
i
]
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
kv_indptr
.
append
(
kv_indptr
[
-
1
]
+
num_blocks
)
kv_last_page_len
=
seq_len
%
block_size
if
kv_last_page_len
==
0
:
kv_last_page_len
=
block_size
kv_last_page_lens
.
append
(
kv_last_page_len
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
zeros
(
1024
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
kv_layout
)
wrapper
.
plan
(
q_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_qo_heads
,
num_kv_heads
,
head_size
,
block_size
,
causal
=
True
,
sm_scale
=
sm_scale
,
q_data_type
=
dtype
,
kv_data_type
=
dtype
,
)
def
time_fn
(
fn
,
warmup
=
10
,
trials
=
20
):
torch
.
cuda
.
synchronize
()
start
=
torch
.
Event
(
enable_timing
=
True
)
end
=
torch
.
Event
(
enable_timing
=
True
)
times
=
[]
for
i
in
range
(
warmup
):
fn
()
for
i
in
range
(
trials
):
start
.
record
()
fn
()
end
.
record
()
torch
.
cuda
.
synchronize
()
times
.
append
(
start
.
elapsed_time
(
end
))
# ms
return
sum
(
times
)
/
len
(
times
),
torch
.
std
(
torch
.
tensor
(
times
))
o_scale
=
1.0
o_sf_scale
=
None
output_baseline
=
torch
.
empty
(
ref_query
.
shape
,
dtype
=
dtype
)
if
o_quant_dtype
==
FP4_DTYPE
:
o_sf_scale
=
500.0
output_trtllm
=
flashinfer
.
utils
.
FP4Tensor
(
torch
.
empty
(
query
.
shape
[:
-
1
]
+
(
query
.
shape
[
-
1
]
//
2
,),
dtype
=
torch
.
uint8
),
torch
.
empty
(
(
round_up
(
query
.
shape
[
0
],
128
),
round_up
(
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
16
,
4
),
),
dtype
=
torch
.
float8_e4m3fn
,
),
)
else
:
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
o_quant_dtype
)
def
baseline_prefill
():
return
wrapper
.
run
(
ref_query
,
ref_kv_cache
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
out
=
output_baseline
,
)
def
trtllm_prefill
():
return
flashinfer
.
prefill
.
trtllm_batch_context_with_kv_cache
(
query
=
query
,
kv_cache
=
kv_cache
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_tables
,
seq_lens
=
seq_lens
,
max_q_len
=
max_q_len
,
max_kv_len
=
max_seq_len
,
bmm1_scale
=
q_scale
*
k_scale
*
sm_scale
,
bmm2_scale
=
v_scale
/
o_scale
,
batch_size
=
batch_size
,
cum_seq_lens_q
=
q_indptr
,
cum_seq_lens_kv
=
kv_indptr
,
o_sf_scale
=
o_sf_scale
,
out
=
output_trtllm
,
)

    baseline_mean, baseline_std = time_fn(baseline_prefill)
    trtllm_mean, trtllm_std = time_fn(trtllm_prefill)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
        f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }
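

# Illustrative only (not called by the benchmark): a minimal sketch of the paged-KV
# index arithmetic used above. For each sequence, ceil(seq_len / block_size) blocks
# are appended to kv_indices, kv_indptr records the running prefix sum of block
# counts, and kv_last_page_len is the number of valid tokens in the final block
# (block_size when the sequence length is an exact multiple). The helper name below
# is our own.
def _paged_kv_indexing_sketch(seq_lens_py, block_size):
    kv_indptr = [0]
    kv_last_page_lens = []
    for seq_len in seq_lens_py:
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        last = seq_len % block_size
        kv_last_page_lens.append(last if last != 0 else block_size)
    return kv_indptr, kv_last_page_lens


# Example: seq_lens [5, 8] with block_size 4 -> kv_indptr [0, 2, 4],
# kv_last_page_lens [1, 4].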


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if not file_exists:
            writer.writeheader()
        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")
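

# Illustrative only (not called by the benchmark): a minimal sketch of reading the
# CSV written above back with the standard csv module to find the row with the
# largest speedup. The helper name and usage are our own, not part of this script's
# command-line flow.
def _best_speedup_from_csv_sketch(filename):
    with open(filename, newline="") as csvfile:
        rows = list(csv.DictReader(csvfile))
    return max(rows, key=lambda row: float(row["speedup_percent"]))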


if __name__ == "__main__":
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (FP8_DTYPE, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]

    for quant_dtype in quant_dtypes:
        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
        q_quant_dtype = q_quant_dtype or dtype
        kv_quant_dtype = kv_quant_dtype or dtype
        o_quant_dtype = o_quant_dtype or dtype

        print(
            f"Running benchmark for q_dtype = {q_quant_dtype}, "
            f"kv_cache_dtype: {kv_quant_dtype}, "
            f"output_dtype: {o_quant_dtype}"
        )
        print(
            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
            "baseline_std\tspeedup_percent"
        )
        for max_seq_len in max_seq_lens:
            for bs in batch_sizes:
                result = benchmark_prefill(
                    dtype=dtype,
                    quant_dtypes=quant_dtype,
                    batch_size=bs,
                    max_seq_len=max_seq_len,
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)

benchmarks/kernels/benchmark_w8a8_block_fp8.py  0 → 100644

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from sglang quantization/tuning_block_wise_kernel.py

import argparse
import json
import multiprocessing as mp
import os
import time
from datetime import datetime
from typing import Any

import torch
from tqdm import tqdm

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _w8a8_triton_block_scaled_mm,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

mp.set_start_method("spawn", force=True)

assert current_platform.is_cuda() or current_platform.is_rocm(), (
    "Only support tune w8a8 block fp8 kernel on CUDA/ROCm device."
)

DTYPE_MAP = {
    "float32": torch.float32,
    "float16": torch.float16,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
}


def w8a8_block_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    config: dict[str, Any],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with
    block-wise quantization.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization.
            It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    if A.dtype == torch.float8_e4m3fn:
        kernel = _w8a8_triton_block_scaled_mm
    else:
        raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")

    kernel[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C
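

# Illustrative only (not called during tuning): a minimal sketch of the shape
# contract documented above, assuming block_size = [128, 128]. For A of shape
# (M, K) the activation scales As have shape (M, ceil(K / 128)); for B of shape
# (N, K) the weight scales Bs have shape (ceil(N / 128), ceil(K / 128)). The helper
# name and the example config values are our own and are not tuned results.
def _w8a8_block_matmul_usage_sketch(M=64, N=512, K=7168):
    block_size = [128, 128]
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
    As = torch.rand(M, triton.cdiv(K, block_size[1]), dtype=torch.float32, device="cuda")
    Bs = torch.rand(
        triton.cdiv(N, block_size[0]),
        triton.cdiv(K, block_size[1]),
        dtype=torch.float32,
        device="cuda",
    )
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3,
    }
    return w8a8_block_matmul(A, B, As, Bs, block_size, config, torch.float16)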


def get_configs_compute_bound():
    configs = []
    for num_stages in [2, 3, 4, 5]:
        for block_m in [16, 32, 64, 128, 256]:
            for block_k in [64, 128]:
                for block_n in [32, 64, 128, 256]:
                    for num_warps in [4, 8]:
                        for group_size in [1, 16, 32, 64]:
                            configs.append(
                                {
                                    "BLOCK_SIZE_M": block_m,
                                    "BLOCK_SIZE_N": block_n,
                                    "BLOCK_SIZE_K": block_k,
                                    "GROUP_SIZE_M": group_size,
                                    "num_warps": num_warps,
                                    "num_stages": num_stages,
                                }
                            )
    return configs


def get_weight_shapes(tp_size):
    # NOTE(HandH1998): These weight shapes only work for DeepSeek-V3.
    # Modify them if you tune for a different model.
    # cannot TP
    total = [
        (512 + 64, 7168),
        (2112, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (12288, 7168),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    weight_shapes = []
    for t in total:
        weight_shapes.append(t)
    for n_t in n_tp:
        new_t = (n_t[0] // tp_size, n_t[1])
        weight_shapes.append(new_t)
    for k_t in k_tp:
        new_t = (k_t[0], k_t[1] // tp_size)
        weight_shapes.append(new_t)
    return weight_shapes
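

# Illustrative only: how the shapes above shrink under tensor parallelism. With
# tp_size = 8, an "N can TP" shape such as (18432 * 2, 7168) becomes
# (36864 // 8, 7168) = (4608, 7168), while a "K can TP" shape such as
# (7168, 18432) becomes (7168, 18432 // 8) = (7168, 2304). The helper below is a
# sketch of that arithmetic and is not part of the tuning flow.
def _weight_shape_tp_sketch():
    shapes = get_weight_shapes(tp_size=8)
    assert (4608, 7168) in shapes  # 18432 * 2 split along N
    assert (7168, 2304) in shapes  # 18432 split along K
    return shapes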


def benchmark_config(
    A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
):
    def run():
        w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)

    torch.cuda.synchronize()
    # JIT compilation & warmup
    for _ in range(5):
        run()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)
    latencies: list[float] = []
    for i in range(num_iters):
        torch.cuda.synchronize()
        start_event.record()
        run()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    return avg


def tune(M, N, K, block_size, out_dtype, search_space, input_type):
    factor_for_scale = 1e-2

    if input_type == "fp8":
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        A_fp32 = (
            (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
        )
        A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        B_fp32 = (
            (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
        )
        B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    else:
        raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")

    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
    Bs = (
        torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
        * factor_for_scale
    )

    best_config = None
    best_time = float("inf")
    for config in tqdm(search_space):
        try:
            kernel_time = benchmark_config(
                A,
                B,
                As,
                Bs,
                block_size,
                config,
                out_dtype,
                num_iters=10,
            )
        except triton.runtime.autotuner.OutOfResources:
            # Some configurations may be invalid and fail to compile.
            continue
        if kernel_time < best_time:
            best_time = kernel_time
            best_config = config
    now = datetime.now()
    print(f"[{now.ctime()}] Completed tuning for batch_size={M}")
    assert best_config is not None
    return best_config


def save_configs(
    N,
    K,
    block_n,
    block_k,
    configs,
    save_path,
    input_type="fp8",
) -> None:
    os.makedirs(save_path, exist_ok=True)
    device_name = current_platform.get_device_name().replace(" ", "_")
    json_file_name = (
        f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
        f"block_shape=[{block_n},{block_k}].json"
    )

    config_file_path = os.path.join(save_path, json_file_name)
    print(f"Writing best config to {config_file_path}...")

    with open(config_file_path, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def tune_on_gpu(args_dict):
    """Run tuning on a specific GPU."""
    gpu_id = args_dict["gpu_id"]
    batch_sizes = args_dict["batch_sizes"]
    weight_shapes = args_dict["weight_shapes"]
    args = args_dict["args"]

    torch.cuda.set_device(gpu_id)
    print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")

    block_n = args.block_n
    block_k = args.block_k
    out_dtype = DTYPE_MAP[args.out_dtype]
    save_path = args.save_path
    input_type = args.input_type

    search_space = get_configs_compute_bound()
    search_space = [
        config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
    ]

    start = time.time()
    for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
        N, K = shape[0], shape[1]
        print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
        benchmark_results = [
            tune(
                batch_size,
                N,
                K,
                [block_n, block_k],
                out_dtype,
                search_space,
                input_type,
            )
            for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
        ]
        best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
        save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)

    end = time.time()
    print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")


def distribute_batch_sizes(batch_sizes, num_gpus):
    """Distribute batch sizes across available GPUs."""
    batches_per_gpu = []
    for i in range(num_gpus):
        start_idx = i * len(batch_sizes) // num_gpus
        end_idx = (i + 1) * len(batch_sizes) // num_gpus
        batches_per_gpu.append(batch_sizes[start_idx:end_idx])
    return batches_per_gpu
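

# Illustrative only: the contiguous split produced by distribute_batch_sizes. With
# 18 batch sizes and 4 GPUs, the slice boundaries fall at indices 0, 4, 9, 13, 18,
# so the GPUs receive 4, 5, 4, and 5 batch sizes respectively. The helper below
# just demonstrates that arithmetic; it is not used by main().
def _distribute_batch_sizes_sketch():
    splits = distribute_batch_sizes(list(range(18)), num_gpus=4)
    assert [len(s) for s in splits] == [4, 5, 4, 5]
    return splits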


def main(args):
    print(args)
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No GPU available for tuning")
    print(f"Found {num_gpus} GPUs for parallel tuning")

    torch.cuda.init()

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]
        num_gpus = 1  # If only one batch size, use only one GPU

    weight_shapes = get_weight_shapes(args.tp_size)

    batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)

    process_args = []
    for gpu_id in range(num_gpus):
        process_args.append(
            {
                "gpu_id": gpu_id,
                "batch_sizes": batches_per_gpu[gpu_id],
                "weight_shapes": weight_shapes,  # Each GPU processes all weight shapes
                "args": args,
            }
        )

    ctx = mp.get_context("spawn")
    with ctx.Pool(num_gpus) as pool:
        pool.map(tune_on_gpu, process_args)

    print("Multi-GPU tuning completed")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="""
Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
    python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs
        """,
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument("--tp-size", "-tp", type=int, default=8)
    parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
    parser.add_argument(
        "--out-dtype",
        type=str,
        choices=["float32", "float16", "bfloat16", "half"],
        default="float16",
    )
    parser.add_argument("--block-n", type=int, default=128)
    parser.add_argument("--block-k", type=int, default=128)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--save-path", type=str, default="./")
    args = parser.parse_args()

    main(args)

benchmarks/kernels/cpu/benchmark_cpu_attn.py  0 → 100644

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import time

import numpy as np
import torch

from vllm._custom_ops import (
    cpu_attention_with_kv_cache,
    cpu_attn_get_scheduler_metadata,
    cpu_attn_reshape_and_cache,
)
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
from vllm.v1.attention.backends.cpu_attn import CPUAttentionBackend, _get_attn_isa


def get_attn_isa(
    block_size: int | None = None,
    dtype: torch.dtype | None = None,
):
    if block_size and dtype:
        return _get_attn_isa(dtype, block_size)
    else:
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            return "neon"
        elif torch._C._cpu._is_amx_tile_supported():
            return "amx"
        else:
            return "vec"


# Random number generation takes too much time, so cache the random tensors.
@functools.lru_cache(maxsize=128, typed=False)
def tensor_cache(
    elem_num: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    tensor = torch.randn(elem_num, dtype=dtype)
    return tensor
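

# Illustrative only: because tensor_cache is wrapped in functools.lru_cache, two
# calls with the same (elem_num, dtype) return the very same tensor object, so the
# expensive torch.randn runs once per unique shape/dtype across benchmark runs.
# The helper below is a sketch of that behaviour and is not called by main().
def _tensor_cache_sketch():
    a = tensor_cache(1024, torch.bfloat16)
    b = tensor_cache(1024, torch.bfloat16)
    assert a is b  # served from the cache, no new allocation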


@torch.inference_mode()
def main(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int = None,
    dtype: torch.dtype = torch.bfloat16,
    block_size: int = 128,
    num_blocks: int = 4096,
    use_sink: bool = False,
    enable_kv_split: bool = False,
    isa: str | None = None,
    seed: int = 0,
    iters: int = 20,
) -> None:
    set_random_seed(seed)

    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
    scale = head_size**-0.5
    token_num = sum(query_lens)

    if isa is None:
        isa = get_attn_isa(block_size, dtype)

    s_aux = (
        15 * torch.rand((num_query_heads,), dtype=torch.bfloat16) if use_sink else None
    )

    query = tensor_cache(
        elem_num=token_num * num_query_heads * head_size,
        dtype=dtype,
    )
    query = query.view(
        token_num,
        num_query_heads,
        head_size,
    )

    key_value = tensor_cache(
        elem_num=2 * num_blocks * num_kv_heads * block_size * head_size,
        dtype=dtype,
    )
    key_value = key_value.view(
        2,
        num_blocks,
        block_size,
        num_kv_heads,
        head_size,
    )
    key_cache, value_cache = key_value.unbind(0)

    # KV cache for CPU attention
    packed_key_cache = torch.empty(
        num_blocks, num_kv_heads, block_size, head_size, dtype=dtype
    )
    packed_value_cache = torch.empty_like(packed_key_cache)

    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    # use reshape_and_cache to pack key_cache and value_cache
    slot_mapping = torch.arange(0, num_blocks * block_size, dtype=torch.int64)
    cpu_attn_reshape_and_cache(
        key=key_cache.view(-1, num_kv_heads, head_size),
        value=value_cache.view(-1, num_kv_heads, head_size),
        key_cache=packed_key_cache,
        value_cache=packed_value_cache,
        slot_mapping=slot_mapping,
        isa=isa,
    )

    metadata = cpu_attn_get_scheduler_metadata(
        num_reqs=num_seqs,
        num_heads=num_query_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_size,
        seq_lens=kv_lens_tensor,
        dtype=dtype,
        query_start_loc=cu_query_lens,
        causal=True,
        sliding_window_size=sliding_window if sliding_window is not None else -1,
        isa=isa,
        enable_kv_split=enable_kv_split,
    )

    out_with_split = torch.empty_like(query)

    def run_benchmark(iters: int) -> list[float]:
        times = []
        for _ in range(iters):
            start_time = time.perf_counter_ns()
            cpu_attention_with_kv_cache(
                query=query,
                key_cache=packed_key_cache,
                value_cache=packed_value_cache,
                output=out_with_split,
                query_start_loc=cu_query_lens,
                seq_lens=kv_lens_tensor,
                scale=scale,
                causal=True,
                alibi_slopes=None,
                sliding_window=window_size,
                block_table=block_tables,
                softcap=0,
                scheduler_metadata=metadata,
                s_aux=s_aux,
            )
            end_time = time.perf_counter_ns()
            times.append((end_time - start_time) / 1e6)
        return times

    # warmup
    run_benchmark(5)
    # benchmark
    times = run_benchmark(iters)

    time_min = min(times)
    time_max = max(times)
    time_mean = np.mean(times)
    time_std = np.std(times)

    print("\tmin (ms) = ", time_min)
    print("\tmax (ms) = ", time_max)
    print("\tmean (ms) = ", time_mean)
    print("\tstd = ", time_std)
    print("\tmedian (ms) = ", np.median(times))


def generate_seq_lens(
    batch_size: int,
    q_len_min: int,
    q_len_max: int,
    kv_len_min: int,
    kv_len_max: int,
    seed: int = 0,
) -> list[tuple[int, int]]:
    assert 1 <= q_len_min <= q_len_max
    assert 1 <= kv_len_min <= kv_len_max
    assert kv_len_max >= q_len_min

    g = torch.Generator(device="cpu").manual_seed(seed)

    def rint(lo: int, hi: int) -> int:
        return torch.randint(lo, hi + 1, (1,), generator=g).item()

    seq_lens: list[tuple[int, int]] = []
    for _ in range(batch_size):
        # ensure q <= kv
        kv = rint(max(kv_len_min, q_len_min), kv_len_max)
        q = rint(q_len_min, min(q_len_max, kv))
        seq_lens.append((q, kv))
    return seq_lens
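

# Illustrative only: generate_seq_lens draws the KV length first and then caps the
# query length at that value, so every (q, kv) pair satisfies q <= kv, as expected
# for a request whose KV cache already holds at least the tokens being queried.
# The sketch below is not called by the benchmark entry point.
def _generate_seq_lens_sketch():
    pairs = generate_seq_lens(
        batch_size=4, q_len_min=16, q_len_max=256, kv_len_min=64, kv_len_max=512, seed=0
    )
    assert all(q <= kv for q, kv in pairs)
    return pairs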


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--q-len-min", type=int, default=512)
    parser.add_argument("--q-len-max", type=int, default=512)
    parser.add_argument("--kv-len-min", type=int, default=512)
    parser.add_argument("--kv-len-max", type=int, default=512)
    parser.add_argument("--num-blocks", type=int, default=4096)
    parser.add_argument("--sliding-window", type=int, default=None)
    parser.add_argument("--num-query-heads", type=int, default=32)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=CPUAttentionBackend.get_supported_head_sizes(),
        default=128,
    )
    parser.add_argument("--enable-kv-split", action="store_true")
    parser.add_argument("--block-size", type=int, choices=[32, 64, 128], default=128)
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
    )
    parser.add_argument("--use-sink", action="store_true")
    parser.add_argument(
        "--isa", type=str, choices=["vec", "neon", "amx", "vec16"], default=None
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--iters", type=int, default=20)
    args = parser.parse_args()
    print(args)

    seq_lens = generate_seq_lens(
        args.batch_size,
        args.q_len_min,
        args.q_len_max,
        args.kv_len_min,
        args.kv_len_max,
        args.seed,
    )
    print("batch (query len, kv len) = ", seq_lens)

    main(
        seq_lens=seq_lens,
        num_heads=(args.num_query_heads, args.num_kv_heads),
        head_size=args.head_size,
        sliding_window=args.sliding_window,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        block_size=args.block_size,
        num_blocks=args.num_blocks,
        use_sink=args.use_sink,
        enable_kv_split=args.enable_kv_split,
        isa=args.isa
        if args.isa is not None
        else get_attn_isa(args.block_size, STR_DTYPE_TO_TORCH_DTYPE[args.dtype]),
        seed=args.seed,
        iters=args.iters,
    )

benchmarks/kernels/cpu/benchmark_cpu_fused_moe.py  0 → 100644

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
import time

import numpy as np
import torch

from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import set_random_seed

# Check if CPU MoE operations are available
try:
    from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
except (ImportError, AttributeError) as e:
    print("ERROR: CPU fused MoE operations are not available on this platform.")
    print("This benchmark requires x86 CPU with proper vLLM CPU extensions compiled.")
    print(
        "The cpu_fused_moe kernel is typically available on Linux x86_64 "
        "with AVX2/AVX512."
    )
    print(f"Import error: {e}")
    sys.exit(1)

# ISA selection following test_cpu_fused_moe.py pattern
ISA_CHOICES = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]


@torch.inference_mode()
def main(
    batch_size: int,
    expert_num: int,
    hidden_size: int,
    intermediate_size: int,
    topk_num: int,
    use_bias: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    activation: str = "silu",
    isa: str = "vec",
    seed: int = 0,
    iters: int = 20,
) -> None:
    set_random_seed(seed)

    # up_dim = 2 * intermediate_size for gate + up projection
    up_dim = 2 * intermediate_size

    input_tensor = torch.randn((batch_size, hidden_size), dtype=dtype) / (
        0.5 * hidden_size**0.5
    )
    w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / (
        0.5 * hidden_size**0.5
    )
    w2 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype) / (
        0.5 * intermediate_size**0.5
    )

    w13_bias = None
    w2_bias = None
    if use_bias:
        w13_bias = torch.randn((expert_num, up_dim), dtype=dtype) / (0.5 * up_dim**0.5)
        w2_bias = torch.randn((expert_num, hidden_size), dtype=dtype) / (
            0.5 * hidden_size**0.5
        )

    router_logits = torch.randn((batch_size, expert_num), dtype=dtype)
    score = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(score, topk_num)
    topk_ids = topk_ids.to(torch.int32)

    packed_w13 = cpu_prepack_moe_weight(w13, isa)
    packed_w2 = cpu_prepack_moe_weight(w2, isa)

    def run_benchmark(iters: int) -> list[float]:
        times = []
        for _ in range(iters):
            start_time = time.perf_counter_ns()
            _ = cpu_fused_moe(
                input_tensor,
                packed_w13,
                packed_w2,
                w13_bias,
                w2_bias,
                topk_weights,
                topk_ids,
                activation,
                isa,
            )
            end_time = time.perf_counter_ns()
            times.append((end_time - start_time) / 1e6)
        return times

    # warmup
    run_benchmark(5)
    # benchmark
    times = run_benchmark(iters)
    if not times:
        print("No iterations to measure. Set --iters > 0.")
        return

    time_min = min(times)
    time_max = max(times)
    time_mean = np.mean(times)
    time_std = np.std(times)

    print("\tmin (ms) = ", time_min)
    print("\tmax (ms) = ", time_max)
    print("\tmean (ms) = ", time_mean)
    print("\tstd = ", time_std)
    print("\tmedian (ms) = ", np.median(times))

    # Calculate throughput metrics
    # FLOPs estimation: 2 * batch * topk * (hidden * up_dim + intermediate * hidden)
    flops_per_token = (
        2 * topk_num * (hidden_size * up_dim + intermediate_size * hidden_size)
    )
    total_flops = batch_size * flops_per_token
    tflops = total_flops / (time_mean * 1e-3) / 1e12
    print(f"\tthroughput (TFLOP/s) = {tflops:.4f}")
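

# Illustrative only: a worked instance of the FLOPs estimate above, using this
# script's default sizes (hidden_size = intermediate_size = 2880, so up_dim = 5760,
# topk_num = 4, batch_size = 64). Each routed token costs roughly
# 2 * 4 * (2880 * 5760 + 2880 * 2880) ≈ 1.99e8 FLOPs, so the whole batch is about
# 1.27e10 FLOPs; dividing by the mean latency gives the reported TFLOP/s. The
# helper below only reproduces that arithmetic and is not called by the benchmark.
def _moe_flops_sketch(batch_size=64, topk_num=4, hidden_size=2880, intermediate_size=2880):
    up_dim = 2 * intermediate_size
    flops_per_token = 2 * topk_num * (
        hidden_size * up_dim + intermediate_size * hidden_size
    )
    return batch_size * flops_per_token  # ~1.27e10 for the defaults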


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the CPU fused MoE kernel.")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--expert-num", type=int, default=8)
    parser.add_argument("--hidden-size", type=int, default=2880)
    parser.add_argument("--intermediate-size", type=int, default=2880)
    parser.add_argument(
        "--topk-num",
        type=int,
        default=None,
        help="Number of experts to route each token to (default: expert_num // 2)",
    )
    parser.add_argument("--use-bias", action="store_true")
    parser.add_argument(
        "--activation",
        type=str,
        choices=["silu", "swigluoai"],
        default="silu",
        help="Activation function",
    )
    parser.add_argument(
        "--isa",
        type=str,
        choices=ISA_CHOICES,
        default=ISA_CHOICES[0],
        help=f"ISA to use (available: {ISA_CHOICES})",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--iters", type=int, default=20)
    args = parser.parse_args()

    # Default topk_num to expert_num // 2, minimum 1
    topk_num = (
        args.topk_num if args.topk_num is not None else max(args.expert_num // 2, 1)
    )

    print(args)
    main(
        batch_size=args.batch_size,
        expert_num=args.expert_num,
        hidden_size=args.hidden_size,
        intermediate_size=args.intermediate_size,
        topk_num=topk_num,
        use_bias=args.use_bias,
        dtype=torch.bfloat16,  # Following test_cpu_fused_moe.py
        activation=args.activation,
        isa=args.isa,
        seed=args.seed,
        iters=args.iters,
    )