raojy / vllm_017 / Commits / fbeb8a6f

Commit fbeb8a6f authored Mar 27, 2026 by raojy

    raw_vllm

Parent: 2ca8867f
Pipeline #3454 canceled with stages
Changes: 165    Pipelines: 1

Showing 20 changed files with 8574 additions and 0 deletions (+8574, -0)
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py        +244   -0
benchmarks/kernels/benchmark_activation.py                   +106   -0
benchmarks/kernels/benchmark_block_fp8_gemm.py               +162   -0
benchmarks/kernels/benchmark_cutlass_moe_fp8.py              +352   -0
benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py            +540   -0
benchmarks/kernels/benchmark_device_communicators.py         +571   -0
benchmarks/kernels/benchmark_fp8_gemm.py                     +159   -0
benchmarks/kernels/benchmark_fused_collective.py             +1137  -0
benchmarks/kernels/benchmark_fused_topk.py                   +99    -0
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py         +429   -0
benchmarks/kernels/benchmark_int8_gemm.py                    +169   -0
benchmarks/kernels/benchmark_layernorm.py                    +95    -0
benchmarks/kernels/benchmark_lora.py                         +1490  -0
benchmarks/kernels/benchmark_machete.py                      +745   -0
benchmarks/kernels/benchmark_marlin.py                       +365   -0
benchmarks/kernels/benchmark_mla_k_concat.py                 +150   -0
benchmarks/kernels/benchmark_moe.py                          +1041  -0
benchmarks/kernels/benchmark_moe_align_block_size.py         +87    -0
benchmarks/kernels/benchmark_moe_defaults.py                 +278   -0
benchmarks/kernels/benchmark_moe_permute_unpermute.py        +355   -0
Too many changes to show: to preserve performance, only 165 of 165+ files are displayed.
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py  (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

from .utils import ArgPool, Bench, CudaGraphBenchParams

GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn


def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    print(
        f"Note : The timings reported above are for {cuda_graph_nops} "
        "consecutive invocations of the benchmarking functions. "
        f"Please divide by {cuda_graph_nops} for single invocation "
        "timings."
    )
    compare = TBenchmark.Compare(timers)
    compare.print()


class ImplType(Enum):
    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    REFERENCE = 2

    def get_impl(self):
        if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        elif self == ImplType.REFERENCE:
            return reference
        raise ValueError(f"Unrecognized ImplType {self}")


@dataclass
class BenchmarkTensors:
    input: torch.Tensor
    output: torch.Tensor
    # Reference act output tensor
    ref_act_out: torch.Tensor
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0

        input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
        # silu_mul_per_token_group_quant_fp8_colmajor output.
        output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
            FLOAT8_T
        )
        # reference output.
        ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
        ref_quant_out = torch.empty(
            (T, N // 2), dtype=torch.bfloat16, device="cuda"
        ).to(FLOAT8_T)

        return BenchmarkTensors(
            input=input,
            output=output,
            ref_act_out=ref_act_out,
            ref_quant_out=ref_quant_out,
        )

    @property
    def T(self):
        return self.input.size(0)

    @property
    def N(self):
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return {
                "input": self.input,
                "output": self.output,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        elif impl_type == ImplType.REFERENCE:
            return {
                "input": self.input,
                "act_out": self.ref_act_out,
                "quant_out": self.ref_quant_out,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        raise ValueError(f"Unrecognized impl_type {impl_type}")


def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """
    Reference triton quant kernel from,
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """
    assert quant_out.size() == x.size()

    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    torch.ops._C.silu_and_mul(act_out, input)
    return reference_quant(act_out, quant_out, use_ue8m0)


def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
    T = bench_tensors[0].T
    N = bench_tensors[0].N
    arg_pool_size = len(bench_tensors)

    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]

    # warmup
    for kwargs in kwargs_list:
        impl_type.get_impl()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)

    timer = None
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl_type.get_impl(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer


def test_correctness(T: int, N: int):
    print(f"Testing num_tokens={T}, N={N} ...")

    bench_tensor = BenchmarkTensors.make(T, N)

    def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))

    # reference output
    ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)
    # test output
    out_q, out_s = output_from_impl(
        ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    )

    torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
    torch.testing.assert_close(ref_out_s, out_s)


def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
    timers = []
    for N, T in product(Ns, Ts):
        test_correctness(T, N)

        bench_tensors: list[BenchmarkTensors] = [
            BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
        ]

        silu_mul_quant_timer = bench_impl(
            bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
        )
        timers.append(silu_mul_quant_timer)

        reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
        timers.append(reference_timer)

        print_timers(
            [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
        )

    print_timers(timers, cuda_graph_nops=arg_pool_size)
    return timers


if __name__ == "__main__":
    T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
    N = [2048, 4096, 8192]
    print(f"T = {T}, N = {N}")
    run(T, N, arg_pool_size=8)
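
Aside: the column-major scale allocation in reference_quant above works by allocating the tensor with the group axis leading and then swapping the last two dims, so the logical (T, N // GROUP_SIZE) view has stride (1, T). A minimal standalone sketch of just that trick, using plain PyTorch (the shapes here are illustrative, not taken from the benchmark):

import torch

GROUP_SIZE = 128

# x: (T, N) activations; one fp32 scale per GROUP_SIZE-wide group.
x = torch.empty(256, 1024)

# Allocate (N // GROUP_SIZE, T) contiguously, then swap the last two dims:
# the logical view becomes (T, N // GROUP_SIZE) while memory stays
# column-major, so scales for consecutive tokens are adjacent.
shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
x_s = torch.empty(shape, dtype=torch.float32).permute(-1, -2)

print(x_s.shape)    # torch.Size([256, 8])
print(x_s.stride())  # (1, 256): token index is the fast-moving axis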
benchmarks/kernels/benchmark_activation.py  (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# benchmark custom activation op performance
import itertools

import torch

import vllm.model_executor.layers.activation  # noqa: F401
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.custom_op import op_registry
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed

batch_size_range = [1, 16, 128]
seq_len_range = [1, 16, 64, 1024, 4096]
intermediate_size = [3072, 9728, 12288]
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))


@default_vllm_config()
def benchmark_activation(
    batch_size: int,
    seq_len: int,
    intermediate_size: int,
    provider: str,
    func_name: str,
    dtype: torch.dtype,
):
    device = "cuda"
    num_tokens = batch_size * seq_len
    dim = intermediate_size
    set_random_seed(42)
    torch.set_default_device(device)

    if func_name == "gelu_and_mul":
        layer = op_registry[func_name](approximate="none")
    elif func_name == "gelu_and_mul_tanh":
        layer = op_registry["gelu_and_mul"](approximate="tanh")
    elif func_name == "fatrelu_and_mul":
        threshold = 0.5
        layer = op_registry[func_name](threshold)
    else:
        layer = op_registry[func_name]()

    x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
    compiled_layer = torch.compile(layer.forward_native)

    if provider == "custom":
        fn = lambda: layer(x)
    elif provider == "compiled":
        fn = lambda: compiled_layer(x)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=[0.5, 0.2, 0.8]
    )
    return ms, max_ms, min_ms


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the custom activation op.")
    parser.add_argument(
        "--func-name",
        type=str,
        choices=[
            "mul_and_silu",
            "silu_and_mul",
            "gelu_and_mul",
            "gelu_and_mul_tanh",
            "fatrelu_and_mul",
            "swigluoai_and_mul",
            "gelu_new",
            "gelu_fast",
            "quick_gelu",
        ],
        default="silu_and_mul",
    )
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
    )

    args = parser.parse_args()
    assert args

    func_name = args.func_name
    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

    perf_report = triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len", "intermediate_size"],
            x_vals=configs,
            line_arg="provider",
            line_vals=["custom", "compiled"],
            line_names=["Custom OP", "Compiled"],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="ms",
            plot_name=f"{func_name}-op-performance",
            args={},
        )
    )
    perf_report(
        lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation(
            batch_size, seq_len, intermediate_size, provider, func_name, dtype
        )
    ).run(print_data=True)
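
Aside: the benchmark above compares each registered custom op against torch.compile of its forward_native implementation. The same custom-vs-compiled pattern can be sketched standalone with a hand-written SiLU-and-mul; this stand-in only mirrors the op's gated-activation semantics (SiLU on the first half of the last dim, multiplied by the second half) and is not vLLM's kernel:

import torch


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Gated activation: SiLU(first half) * second half along the last dim.
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]


compiled = torch.compile(silu_and_mul)

x = torch.randn(16, 2048)
out_eager = silu_and_mul(x)
out_compiled = compiled(x)  # first call triggers compilation
torch.testing.assert_close(out_eager, out_compiled)

Either callable could then be passed as fn to a timing harness such as triton.testing.do_bench_cudagraph, which is exactly what benchmark_activation does per provider.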
benchmarks/kernels/benchmark_block_fp8_gemm.py  (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

# Disable DeepGEMM for this benchmark to use CUTLASS
os.environ["VLLM_USE_DEEP_GEMM"] = "0"

import torch

from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton as vllm_triton

assert current_platform.is_cuda(), (
    "Only support benchmarking w8a8 block fp8 kernel on CUDA device."
)

# DeepSeek-V3 weight shapes
DEEPSEEK_V3_SHAPES = [
    (512 + 64, 7168),
    (2112, 7168),
    ((128 + 64) * 128, 7168),
    (128 * (128 + 128), 512),
    (7168, 16384),
    (7168, 18432),
    (18432 * 2, 7168),
    (24576, 1536),
    (12288, 7168),
    (4096, 7168),
    (7168, 2048),
]


@default_vllm_config()
def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
    """Build runner function for w8a8 block fp8 matmul."""
    factor_for_scale = 1e-2

    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    # Create random input tensor (bfloat16, will be quantized by
    # W8A8BlockFp8LinearOp)
    A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max

    # Create quantized weight tensor
    B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
    B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    # Create weight scales
    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    Bs = (
        torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
        * factor_for_scale
    )

    # Create W8A8BlockFp8LinearOp instance
    weight_group_shape = GroupShape(block_n, block_k)
    act_quant_group_shape = GroupShape(1, block_k)
    # Per-token, per-group quantization
    linear_op = W8A8BlockFp8LinearOp(
        weight_group_shape=weight_group_shape,
        act_quant_group_shape=act_quant_group_shape,
        cutlass_block_fp8_supported=use_cutlass,
        use_aiter_and_is_supported=False,
    )

    def run():
        return linear_op.apply(
            input=A_ref,
            weight=B,
            weight_scale=Bs,
            input_scale=None,
            bias=None,
        )

    return run


# Determine available providers
available_providers = ["torch-bf16", "w8a8-block-fp8-triton"]
plot_title = "BF16 vs W8A8 Block FP8 GEMMs"

if CUTLASS_BLOCK_FP8_SUPPORTED:
    available_providers.append("w8a8-block-fp8-cutlass")


@vllm_triton.testing.perf_report(
    vllm_triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=available_providers,
        line_names=available_providers,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs W8A8 Block FP8 GEMMs",
        args={},
    )
)
def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
    M = batch_size
    device = "cuda"

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        a = torch.randn((M, K), device=device, dtype=torch.bfloat16)
        b = torch.randn((N, K), device=device, dtype=torch.bfloat16)
        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    elif provider == "w8a8-block-fp8-triton":
        run_w8a8_triton = build_w8a8_block_fp8_runner(
            M, N, K, block_size, device, use_cutlass=False
        )
        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
            lambda: run_w8a8_triton(), quantiles=quantiles
        )
    elif provider == "w8a8-block-fp8-cutlass":
        run_w8a8_cutlass = build_w8a8_block_fp8_runner(
            M, N, K, block_size, device, use_cutlass=True
        )
        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
            lambda: run_w8a8_cutlass(), quantiles=quantiles
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


if __name__ == "__main__":
    block_size = (128, 128)

    for N, K in DEEPSEEK_V3_SHAPES:
        print(f"\nBenchmarking DeepSeek-V3, N={N} K={K}")
        print(f"TFLOP/s comparison (block_size={block_size}):")
        benchmark_tflops.run(
            print_data=True,
            # show_plots=False,
            # save_path=f"bench_w8a8_block_fp8_tflops_n{N}_k{K}",
            N=N,
            K=K,
            block_size=block_size,
        )

    print("\nBenchmark finished!")
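
Aside: the to_tflops lambda above converts a measured latency into throughput using the standard GEMM operation count of 2 * M * N * K (one multiply plus one add per accumulated product). A worked instance of the same arithmetic with hypothetical numbers:

# GEMM FLOP count: 2 * M * N * K (multiply + add per output accumulation).
M, N, K = 4096, 7168, 16384
t_ms = 5.0  # hypothetical measured latency in milliseconds

flops = 2 * M * N * K                   # total floating-point operations
tflops = flops * 1e-12 / (t_ms * 1e-3)  # ops -> TFLOP, ms -> s
print(f"{tflops:.1f} TFLOP/s")          # ~192.4 TFLOP/s for these numbers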
benchmarks/kernels/benchmark_cutlass_moe_fp8.py  (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe
kernel. Both kernels take in fp8 quantized weights and 16-bit activations,
but use different quantization strategies and backends.
"""

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager

# Weight shapes for different models: [num_experts, topk, hidden_size,
# intermediate_size]
WEIGHT_SHAPES_MOE = {
    "mixtral-8x7b": [
        [8, 2, 4096, 14336],
    ],
    "deepseek-v2": [
        [160, 6, 5120, 12288],
    ],
    "custom-small": [
        [8, 2, 2048, 7168],
    ],
    "glm45-fp8": [
        [128, 8, 4096, 1408],
    ],
    "Llama-4-Maverick-17B-128E-Instruct-FP8": [
        [128, 1, 5120, 8192],
    ],
}

DEFAULT_MODELS = [
    "mixtral-8x7b",
]

DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False, True]
PER_OUT_CH_OPTS = [False, True]

FP8_DTYPE = current_platform.fp8_dtype()


def bench_run(
    results: list,
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    init_workspace_manager(torch.cuda.current_device())

    (m, k, n) = mkn

    dtype = torch.half
    device = "cuda"

    # Create input activations
    a = torch.randn((m, k), device=device, dtype=dtype) / 10

    # Create weights
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    # Create FP8 quantized weights and scales for both kernels
    w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE)
    w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE)

    # Create scales based on quantization strategy
    if per_out_ch:
        # Per-channel quantization
        w1_scale = torch.empty(
            (num_experts, 2 * n, 1), device=device, dtype=torch.float32
        )
        w2_scale = torch.empty((num_experts, k, 1), device=device, dtype=torch.float32)
    else:
        # Per-tensor quantization
        w1_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
        w2_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)

    # Quantize weights
    for expert in range(num_experts):
        if per_out_ch:
            # Per-channel quantization - not yet implemented properly
            # For now, fall back to per-tensor quantization
            w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
            w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
            # Expand scalar scales to the expected per-channel shape
            w1_scale[expert] = w1_scale_temp.expand(2 * n, 1)
            w2_scale[expert] = w2_scale_temp.expand(k, 1)
        else:
            # Per-tensor quantization
            w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
            w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
            # Store scalar scales in [1, 1] tensors
            w1_scale[expert, 0, 0] = w1_scale_temp
            w2_scale[expert, 0, 0] = w2_scale_temp

    # Prepare weights for CUTLASS (no transpose needed)
    w1_fp8q_cutlass = w1_fp8q  # Keep original [E, 2N, K]
    w2_fp8q_cutlass = w2_fp8q  # Keep original [E, K, N]

    # Create router scores and get topk
    score = torch.randn((m, num_experts), device=device, dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    # WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization
    # Force per-tensor quantization for all cases to match working e2e setup
    a1_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
    a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)

    # Force per-tensor quantization for all cases
    per_act_token = False

    # Pre-create quantization config to avoid creating it inside CUDA graph
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_act_token,
        per_out_ch_quant=per_out_ch,
    )

    moe_config = make_dummy_moe_config(
        num_experts=num_experts,
        hidden_dim=k,
        intermediate_size_per_partition=n,
        in_dtype=a.dtype,
    )

    fn = mk.FusedMoEKernel(
        maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=False,
        ),
        CutlassExpertsFp8(
            moe_config=moe_config,
            quant_config=quant_config,
        ),
    )

    # Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly)
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        # Capture 10 invocations like benchmark_moe.py
        for _ in range(10):
            fn(
                a,
                w1_fp8q_cutlass,
                w2_fp8q_cutlass,
                topk_weights,
                topk_ids,
                activation=MoEActivation.SILU,
                global_num_experts=num_experts,
            )
    torch.cuda.synchronize()

    # Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly)
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        # Capture 10 invocations like benchmark_moe.py
        for _ in range(10):
            fused_experts(
                a,
                w1_fp8q,
                w2_fp8q,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )
    torch.cuda.synchronize()

    def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
        """Benchmark CUDA graph using events like benchmark_moe.py"""
        # Warmup
        for _ in range(num_warmup):
            graph.replay()
        torch.cuda.synchronize()

        # Timing
        start_event = torch.Event(enable_timing=True)
        end_event = torch.Event(enable_timing=True)

        latencies = []
        for _ in range(num_iters):
            torch.cuda.synchronize()
            start_event.record()
            graph.replay()
            end_event.record()
            end_event.synchronize()
            latencies.append(start_event.elapsed_time(end_event))

        # Divide by 10 since graph contains 10 calls
        return sum(latencies) / (num_iters * 10)

    # Benchmark parameters
    num_warmup = 5
    num_iters = 100

    # Benchmark only CUDA graphs (more reliable and faster)
    # Benchmark Triton MoE with CUDA graphs
    triton_graph_time = bench_cuda_graph(
        triton_graph, num_warmup=num_warmup, num_iters=num_iters
    )

    # Benchmark CUTLASS MoE with CUDA graphs
    cutlass_graph_time = bench_cuda_graph(
        cutlass_graph, num_warmup=num_warmup, num_iters=num_iters
    )

    # Convert ms to us and return results
    triton_time_us = triton_graph_time * 1000
    cutlass_time_us = cutlass_graph_time * 1000

    return {
        "batch_size": m,
        "triton_time_us": triton_time_us,
        "cutlass_time_us": cutlass_time_us,
    }


def main(args):
    # Initialize workspace manager (required for CUTLASS MoE kernels)
    device = torch.device("cuda:0")
    init_workspace_manager(device)

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    all_results = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in args.per_act_token_opts:
                    for per_out_ch in args.per_out_ch_opts:
                        print(
                            f"\n=== {model}, experts={num_experts}, topk={topk}, "
                            f"per_act={per_act_token}, per_out_ch={per_out_ch} ==="
                        )

                        config_results = []
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            result = bench_run(
                                [],  # Not used anymore
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )
                            if result:
                                config_results.append(result)

                        # Print results table for this configuration
                        if config_results:
                            print(
                                f"\n{'Batch Size':<12}"
                                f"{'Triton (us)':<15}"
                                f"{'CUTLASS (us)':<15}"
                            )
                            print("-" * 45)
                            for result in config_results:
                                print(
                                    f"{result['batch_size']:<12}"
                                    f"{result['triton_time_us']:<15.2f}"
                                    f"{result['cutlass_time_us']:<15.2f}"
                                )

                            all_results.extend(config_results)

    print(f"\nTotal benchmarks completed: {len(all_results)}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE
        across specified models/shapes/batches

        Example usage:
        python benchmark_cutlass_moe_fp8.py \
            --models "Llama-4-Maverick-17B-128E-Instruct-FP8" \
            --tp-sizes 8 \
            --batch-sizes 2 4 8 \
            --per-act-token-opts false \
            --per-out-ch-opts false
        """
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument(
        "--per-act-token-opts",
        nargs="+",
        type=lambda x: x.lower() == "true",
        default=[False, True],
        help="Per-activation token quantization options (true/false)",
    )
    parser.add_argument(
        "--per-out-ch-opts",
        nargs="+",
        type=lambda x: x.lower() == "true",
        default=[False, True],
        help="Per-output channel quantization options (true/false)",
    )

    args = parser.parse_args()
    main(args)
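
Aside: the capture-ten-invocations-and-divide pattern used by bench_run above (and documented in bench_cuda_graph) can be distilled into a standalone sketch on a plain matmul. This is a minimal illustration, not the benchmark itself; it assumes a CUDA device and uses only public torch.cuda APIs:

import torch


def bench_graph_of_10(fn, num_warmup=5, num_iters=100):
    # Capture 10 invocations in a single CUDA graph, mirroring the benchmark.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            fn()
    torch.cuda.synchronize()

    for _ in range(num_warmup):
        graph.replay()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    latencies = []
    for _ in range(num_iters):
        start.record()
        graph.replay()
        end.record()
        end.synchronize()
        latencies.append(start.elapsed_time(end))  # milliseconds

    # Each replay runs 10 captured calls, so divide the mean by 10.
    return sum(latencies) / (num_iters * 10)


a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
# Warm the op up outside capture (cuBLAS workspace allocation, autotuning).
torch.matmul(a, b)
torch.cuda.synchronize()
ms = bench_graph_of_10(lambda: torch.matmul(a, b))
print(f"{ms * 1000:.2f} us per matmul")

Capturing several invocations per graph amortizes the replay launch overhead, which matters when the kernel under test is only a few microseconds long.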
benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py  (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights (tensor scaled to fp8)
and 16-bit activations.
"""

import nvtx
import torch
import torch.utils.benchmark as benchmark

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager

WEIGHT_SHAPES_MOE = {
    "nvidia/DeepSeek-R1-FP4": [
        [256, 8, 2048, 7168],
    ],
}

DEFAULT_MODELS = [
    "nvidia/DeepSeek-R1-FP4",
]

DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )


def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half
    device = "cuda"

    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    _, a_fp8_scale = ops.scaled_fp8_quant(a)

    w1_fp8q = torch.empty(
        (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
    )
    w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
    w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
    w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])

    w1_fp8q_notransp = w1_fp8q.clone()
    w2_fp8q_notransp = w2_fp8q.clone()
    w1_fp8q = w1_fp8q.transpose(1, 2)
    w2_fp8q = w2_fp8q.transpose(1, 2)

    score = torch.randn((m, num_experts), device=device, dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    quant_blocksize = 16
    w1_blockscale = torch.empty(
        (num_experts, 2 * n, k // quant_blocksize),
        device=device,
        dtype=torch.float8_e4m3fn,
    )
    w2_blockscale = torch.empty(
        (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
    )

    # n_b_scales = 2 * n if per_out_ch else 1
    # k_b_scales = k if per_out_ch else 1
    w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
    w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)

    w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
    w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
    a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
    a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_e = w1[expert]
        w2_e = w2[expert]
        w1_amax = torch.abs(w1_e).max().to(torch.float32)
        w2_amax = torch.abs(w2_e).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
            w1_e, w1_gs[expert]
        )
        w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
            w2_e, w2_gs[expert]
        )

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_fp8_scale: torch.Tensor,
        num_repeats: int,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_fp8_scale,
        )
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def run_cutlass_moe_fp4(
        a: torch.Tensor,
        w1_fp4: torch.Tensor,
        w2_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w2_blockscale: torch.Tensor,
        w1_gs: torch.Tensor,
        w2_gs: torch.Tensor,
        a1_gs: torch.Tensor,
        a2_gs: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
        num_repeats: int,
    ):
        quant_config = nvfp4_moe_quant_config(
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
            g1_alphas=w1_gs,
            g2_alphas=w2_gs,
        )
        moe_config = make_dummy_moe_config(
            num_experts=num_experts,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            in_dtype=a.dtype,
        )
        kernel = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            CutlassExpertsFp4(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
        )
        for _ in range(num_repeats):
            with nvtx.annotate("cutlass_moe_fp4", color="green"):
                kernel(
                    hidden_states=a,
                    w1=w1_fp4,
                    w2=w2_fp4,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                )

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a1_gscale: torch.Tensor,
        w1_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w1_alphas: torch.Tensor,
        a2_gscale: torch.Tensor,
        w2_fp4: torch.Tensor,
        w2_blockscale: torch.Tensor,
        w2_alphas: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
    ):
        quant_config = nvfp4_moe_quant_config(
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
            g1_alphas=w1_gs,
            g2_alphas=w2_gs,
        )
        moe_config = make_dummy_moe_config()
        kernel = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            CutlassExpertsFp4(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
        )
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return kernel(
                hidden_states=a,
                w1=w1_fp4,
                w2=w2_fp4,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
            )

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_fp8_scale: torch.Tensor,
    ):
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            quant_config = fp8_w8a8_moe_quant_config(
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_fp8_scale,
            )
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a=a,
            a1_gscale=a1_gs,
            w1_fp4=w1_fp4,
            w1_blockscale=w1_blockscale,
            w1_alphas=w1_gs,
            a2_gscale=a2_gs,
            w2_fp4=w2_fp4,
            w2_blockscale=w2_blockscale,
            w2_alphas=w2_gs,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            m=m,
            n=n,
            k=k,
            e=num_experts,
            device=device,
        )
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_fp8q_notransp,
            w2_fp8q_notransp,
            topk_weights,
            topk_ids,
            w1_fp8scale,
            w2_fp8scale,
            a_fp8_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_fp8q_notransp": w1_fp8q_notransp,
        "w2_fp8q_notransp": w2_fp8q_notransp,
        "w1_fp8scale": w1_fp8scale,
        "w2_fp8scale": w2_fp8scale,
        "a_fp8_scale": a_fp8_scale,
        # Cutlass params
        "a": a,
        "a1_gscale": a1_gs,
        "w1_fp4": w1_fp4,
        "w1_blockscale": w1_blockscale,
        "w1_alphas": w1_gs,
        "a2_gscale": a2_gs,
        "w2_fp4": w2_fp4,
        "w2_blockscale": w2_blockscale,
        "w2_alphas": w2_gs,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "m": m,
        "n": n,
        "k": k,
        "e": num_experts,
        "device": device,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe_fp4": run_cutlass_moe_fp4,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_fp8q_notransp,
        w2_fp8q_notransp,
        topk_weights,
        topk_ids,
        w1_fp8scale,
        w2_fp8scale,
        a_fp8_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe_fp4(
        a,
        w1_fp4,
        w2_fp4,
        w1_blockscale,
        w2_blockscale,
        w1_gs,
        w2_gs,
        a1_gs,
        a2_gs,
        topk_weights,
        topk_ids,
        m,
        n,
        k,
        num_experts,
        device,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="cutlass_moe_fp4",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="cutlass_moe_fp4_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )


def main(args):
    # Initialize workspace manager (required for CUTLASS MoE kernels)
    device = torch.device("cuda:0")
    init_workspace_manager(device)

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
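
Aside: the per-expert global scale computed above, w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax(w), sizes the combined fp4-element range (max 6.0 for e2m1) times the fp8 block-scale range (max 448 for e4m3) to the weight's absolute maximum. A small numeric sketch with those constants hard-coded (the benchmark reads them from vllm.scalar_type and torch.finfo; the dequantization comment is an approximation of the scheme, not the kernel's exact formula):

import torch

FLOAT4_E2M1_MAX = 6.0    # largest magnitude representable in fp4 e2m1
FLOAT8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

w = torch.randn(256, 7168) / 10
amax = w.abs().max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax
# With fp8 block scale s and fp4 element q, the dequantized value is
# roughly q * s / global_scale, so q * s can span the full 6 * 448 range
# while still covering every weight up to amax.
print(f"amax={amax:.4f} -> global_scale={global_scale:.1f}")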
benchmarks/kernels/benchmark_device_communicators.py  (new file, 0 → 100644)
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Benchmark script for device communicators:
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
and SymmMemCommunicator (multimem, two-shot).

For NCCL symmetric memory, set the environment variables
NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1; otherwise
NCCL does not use the fast NVLS implementation for all-reduce.

Usage:
    torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]

Example:
    torchrun --nproc_per_node=2 benchmark_device_communicators.py
        --sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
"""

import json
import os
import time
from collections.abc import Callable
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
    FlashInferAllReduce,
)
from vllm.distributed.device_communicators.pynccl import (
    PyNcclCommunicator,
    register_nccl_symmetric_ops,
)
from vllm.distributed.device_communicators.pynccl_allocator import (
    set_graph_pool_id,
)
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser

logger = init_logger(__name__)

# Default sequence lengths to benchmark
DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192]

# Fixed hidden size and dtype for all benchmarks
HIDDEN_SIZE = 8192
BENCHMARK_DTYPE = torch.bfloat16

# CUDA graph settings
CUDA_GRAPH_CAPTURE_CYCLES = 10


class CommunicatorBenchmark:
    """Benchmark class for testing device communicators."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        device: torch.device,
        cpu_group: ProcessGroup,
        sequence_lengths: list[int],
    ):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.cpu_group = cpu_group

        # Calculate max_size_override based on largest sequence length
        max_seq_len = max(sequence_lengths)
        max_tensor_elements = max_seq_len * HIDDEN_SIZE
        self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1

        # Initialize communicators
        self.custom_allreduce = None
        self.pynccl_comm = None
        self.symm_mem_comm = None
        self.symm_mem_comm_multimem = None
        self.symm_mem_comm_two_shot = None
        self.fi_ar_comm = None

        self._init_communicators()

    def _init_communicators(self):
        """Initialize all available communicators."""
        try:
            self.custom_allreduce = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
                max_size=self.max_size_override,
            )
            if not self.custom_allreduce.disabled:
                logger.info("Rank %s: CustomAllreduce initialized", self.rank)
            else:
                logger.info("Rank %s: CustomAllreduce disabled", self.rank)
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e
            )
            self.custom_allreduce = None

        try:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group, device=self.device
            )
            if not self.pynccl_comm.disabled:
                logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
                register_nccl_symmetric_ops(self.pynccl_comm)
            else:
                logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
                self.pynccl_comm = None
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e
            )
            self.pynccl_comm = None

        # Initialize variants for SymmMemCommunicator
        try:
            self.symm_mem_comm_multimem = SymmMemCommunicator(
                group=self.cpu_group,
                device=self.device,
                force_multimem=True,
                max_size_override=self.max_size_override,
            )
            if not self.symm_mem_comm_multimem.disabled:
                logger.info(
                    "Rank %s: SymmMemCommunicator (multimem) initialized", self.rank
                )
            else:
                self.symm_mem_comm_multimem = None
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s",
                self.rank,
                e,
            )
            self.symm_mem_comm_multimem = None

        try:
            self.symm_mem_comm_two_shot = SymmMemCommunicator(
                group=self.cpu_group,
                device=self.device,
                force_multimem=False,
                max_size_override=self.max_size_override,
            )
            if not self.symm_mem_comm_two_shot.disabled:
                logger.info(
                    "Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank
                )
            else:
                self.symm_mem_comm_two_shot = None
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s",
                self.rank,
                e,
            )
            self.symm_mem_comm_two_shot = None

        try:
            self.fi_ar_comm = FlashInferAllReduce(
                group=self.cpu_group,
                device=self.device,
            )
            if not self.fi_ar_comm.disabled:
                logger.info("Rank %s: FlashInferAllReduce initialized", self.rank)
            else:
                logger.info("Rank %s: FlashInferAllReduce disabled", self.rank)
                self.fi_ar_comm = None
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e
            )
            self.fi_ar_comm = None

    def benchmark_allreduce(
        self, sequence_length: int, num_warmup: int, num_trials: int
    ) -> dict[str, float]:
        """Benchmark allreduce operations for all available communicators."""
        results = {}

        # Define communicators with their benchmark functions
        communicators = []

        if self.custom_allreduce is not None:
            comm = self.custom_allreduce
            # CustomAllreduce one-shot
            communicators.append(
                (
                    "ca_1stage",
                    lambda t, c=comm: c.custom_all_reduce(t),
                    lambda t, c=comm: c.should_custom_ar(t),
                    comm.capture(),
                    {"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"},
                    None,  # no destroy function
                )
            )
            # CustomAllreduce two-shot
            communicators.append(
                (
                    "ca_2stage",
                    lambda t, c=comm: c.custom_all_reduce(t),
                    lambda t, c=comm: c.should_custom_ar(t),
                    comm.capture(),
                    {"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"},
                    None,  # no destroy function
                )
            )

        if self.pynccl_comm is not None:
            comm = self.pynccl_comm
            communicators.append(
                (
                    "pynccl",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t: True,  # Always available if initialized
                    nullcontext(),
                    {},  # no env variable needed
                    None,  # no destroy function
                )
            )
            communicators.append(
                (
                    "pynccl-symm",
                    lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
                    lambda t: True,  # Always available if initialized
                    nullcontext(),
                    {},  # no env variable needed
                    None,  # no destroy function
                )
            )

        if self.symm_mem_comm_multimem is not None:
            comm = self.symm_mem_comm_multimem
            communicators.append(
                (
                    "symm_mem_multimem",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t, c=comm: c.should_use_symm_mem(t),
                    nullcontext(),
                    {},  # no env variable needed
                    None,  # no destroy function
                )
            )

        if self.symm_mem_comm_two_shot is not None:
            comm = self.symm_mem_comm_two_shot
            communicators.append(
                (
                    "symm_mem_two_shot",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t, c=comm: c.should_use_symm_mem(t),
                    nullcontext(),
                    {},  # no env variable needed
                    None,  # no destroy function needed
                )
            )

        if self.fi_ar_comm is not None:
            comm = self.fi_ar_comm
            communicators.append(
                (
                    "flashinfer_trtllm",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t, c=comm: c.should_use_fi_ar(t),
                    nullcontext(),
                    {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"},
                    lambda c=comm: c.destroy(),
                )
            )
            communicators.append(
                (
                    "flashinfer_mnnvl",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t, c=comm: c.should_use_fi_ar(t),
                    nullcontext(),
                    {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"},
                    lambda c=comm: c.destroy(),
                )
            )

        # Benchmark each communicator
        for (
            name,
            allreduce_fn,
            should_use_fn,
            context,
            env_dict,
            destroy_fn,
        ) in communicators:
            # Save original values and apply new environment variables
            saved_env = {key: os.environ.get(key) for key in env_dict}
            for key, value in env_dict.items():
                os.environ[key] = value

            try:
                latency = self.benchmark_allreduce_single(
                    sequence_length,
                    allreduce_fn,
                    should_use_fn,
                    context,
                    num_warmup,
                    num_trials,
                )
                if latency is not None:
                    results[name] = latency
            finally:
                if destroy_fn is not None:
                    destroy_fn()
                # Restore environment variables to their original state
                for key, original_value in saved_env.items():
                    if original_value is None:
                        os.environ.pop(key, None)
                    else:
                        os.environ[key] = original_value

        return results

    def benchmark_allreduce_single(
        self,
        sequence_length: int,
        allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None],
        should_use_fn: Callable[[torch.Tensor], bool],
        context,
        num_warmup: int,
        num_trials: int,
    ) -> float | None:
        """Benchmark method with CUDA graph optimization."""
        try:
            # Create test tensor (2D: sequence_length x hidden_size)
            tensor = torch.randn(
                sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device
            )
            if not should_use_fn(tensor):
                return None

            torch.cuda.synchronize()
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                graph_input = tensor.clone()

                # Warmup before capture
                for _ in range(3):
                    allreduce_fn(graph_input)

                # Capture the graph using context manager
                with context:
                    graph = torch.cuda.CUDAGraph()
                    graph_pool = torch.cuda.graph_pool_handle()
                    set_graph_pool_id(graph_pool)
                    with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
                        for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                            allreduce_fn(graph_input)

            torch.cuda.synchronize()
            for _ in range(num_warmup):
                graph.replay()
            torch.cuda.synchronize()

            torch.cuda.synchronize()
            start_time = time.perf_counter()
            for _ in range(num_trials):
                graph.replay()
            torch.cuda.synchronize()
            end_time = time.perf_counter()

            # Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES
            return (
                (end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000
            )
        except Exception as e:
            logger.error("CUDA graph benchmark failed: %s", e)
            raise RuntimeError(
                f"CUDA graph benchmark failed for communicator: {e}"
            ) from e


def _calculate_speedup_info(comm_results: dict[str, float]) -> str:
    """Calculate speedup information for a single tensor size."""
    if not comm_results:
        return "N/A"

    # Find the fastest communicator
    fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k])
    fastest_time = comm_results[fastest_comm]

    # Calculate speedup vs PyNccl if available
    if "pynccl" in comm_results:
        pynccl_time = comm_results["pynccl"]
        speedup = pynccl_time / fastest_time
        return f"{fastest_comm} ({speedup:.2f}x)"
    else:
        return f"{fastest_comm} (N/A)"


def print_results(
    results: dict[str, dict[str, float]],
    sequence_lengths: list[int],
    world_size: int,
):
    """Print benchmark results in a formatted table."""
    print(f"\n{'=' * 130}")
    print("Device Communicator Benchmark Results")
    print(
        f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, "
        f"Hidden Size: {HIDDEN_SIZE}"
    )
    print(f"{'=' * 130}")

    # Get all communicator names
    all_comms = set()
    for size_results in results.values():
        all_comms.update(size_results.keys())
    all_comms = sorted(list(all_comms))

    # Print header
    header = f"{'Tensor Shape':<20}{'Tensor Size':<15}"
    for comm in all_comms:
        header += f"{comm:<20}"
    header += f"{'Best (Speedup vs PyNccl)':<30}"
    print(header)
    print("-" * len(header))

    # Print results for each sequence length
    for seq_len in sequence_lengths:
        if seq_len in results:
            # Calculate tensor size in elements and bytes
            tensor_elements = seq_len * HIDDEN_SIZE
            tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize

            # Format tensor size (MB)
            tensor_size_mb = tensor_bytes / (1024 * 1024)
            tensor_size_str = f"{tensor_size_mb:.2f} MB"

            # Format tensor shape
            tensor_shape = f"({seq_len}, {HIDDEN_SIZE})"

            row = f"{tensor_shape:<20}{tensor_size_str:<15}"
            for comm in all_comms:
                if comm in results[seq_len]:
                    row += f"{results[seq_len][comm]:<20.3f}"
                else:
                    row += f"{'N/A':<20}"

            # Calculate speedup information
            speedup_info = _calculate_speedup_info(results[seq_len])
            row += f"{speedup_info:<30}"

            print(row)

    print(f"{'=' * 130}")
    print("All times are in milliseconds (ms) per allreduce operation")
    print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)")


def main():
    parser = FlexibleArgumentParser(description="Benchmark device communicators")

    parser.add_argument(
        "--sequence-lengths",
        type=int,
        nargs="+",
        default=DEFAULT_SEQUENCE_LENGTHS,
        help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)",
    )
    parser.add_argument(
        "--num-warmup", type=int, default=5, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--num-trials", type=int, default=50, help="Number of benchmark trials"
    )
    parser.add_argument("--output-json", type=str, help="Output results to JSON file")

    args = parser.parse_args()

    # Initialize distributed
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Set device
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Get CPU process group
    cpu_group = dist.new_group(backend="gloo")

    # Disable USE_SYMM_MEM to avoid affecting the max_sizes
    # in symm_mem and custom_all_reduce for benchmark
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    # Initialize benchmark
    benchmark = CommunicatorBenchmark(
        rank, world_size, device, cpu_group, args.sequence_lengths
    )

    # Run benchmarks
    all_results = {}
    for seq_len in args.sequence_lengths:
        if rank == 0:
            logger.info(
                "Benchmarking sequence length: %s (tensor shape: %s x %s)",
                seq_len,
                seq_len,
                HIDDEN_SIZE,
            )

        results = benchmark.benchmark_allreduce(
            sequence_length=seq_len,
            num_warmup=args.num_warmup,
            num_trials=args.num_trials,
        )
        all_results[seq_len] = results

        # Synchronize between ranks
        dist.barrier()

    # Print results (only rank 0)
    if rank == 0:
        print_results(all_results, args.sequence_lengths, world_size)

        # Save to JSON if requested
        if args.output_json:
            # Add speedup information to results
            enhanced_results = {}
            for seq_len, comm_results in all_results.items():
                enhanced_results[seq_len] = {
                    "timings": comm_results,
                    "speedup_info": _calculate_speedup_info(comm_results),
                }

            output_data = {
                "world_size": world_size,
                "dtype": str(BENCHMARK_DTYPE),
                "hidden_size": HIDDEN_SIZE,
                "sequence_lengths": args.sequence_lengths,
                "num_warmup": args.num_warmup,
                "num_trials": args.num_trials,
                "cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES,
                "results": enhanced_results,
            }

            with open(args.output_json, "w") as f:
                json.dump(output_data, f, indent=2)
            logger.info("Results saved to %s", args.output_json)

    # Cleanup
    if cpu_group != dist.group.WORLD:
        dist.destroy_process_group(cpu_group)


if __name__ == "__main__":
    main()
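
Aside: benchmark_allreduce above saves each overridden environment variable, applies the per-communicator settings, and restores the originals in a finally block. The same pattern can be packaged as a context manager; a minimal standard-library sketch (not part of the benchmark itself):

import os
from contextlib import contextmanager


@contextmanager
def scoped_env(env: dict[str, str]):
    # Save original values; None marks variables that were unset before.
    saved = {key: os.environ.get(key) for key in env}
    os.environ.update(env)
    try:
        yield
    finally:
        # Restore: unset what was unset, reset what had a value.
        for key, original in saved.items():
            if original is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = original


with scoped_env({"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"}):
    assert os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] == "1stage"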
benchmarks/kernels/benchmark_fp8_gemm.py  (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "fp8-tensor-w-token-a": dict(w="tensor", a="token", no_a_quant=False, enabled=False),
    "fp8-tensor-w-tensor-a": dict(w="tensor", a="tensor", no_a_quant=False, enabled=True),
    "fp8-channel-w-token-a": dict(w="channel", a="token", no_a_quant=False, enabled=True),
    "fp8-channel-w-tensor-a": dict(w="channel", a="tensor", no_a_quant=False, enabled=False),
    "fp8-tensor-w-token-a-noquant": dict(w="tensor", a="token", no_a_quant=True, enabled=False),
    "fp8-tensor-w-tensor-a-noquant": dict(w="tensor", a="tensor", no_a_quant=True, enabled=True),
    "fp8-channel-w-token-a-noquant": dict(w="channel", a="token", no_a_quant=True, enabled=True),
    "fp8-channel-w-tensor-a-noquant": dict(w="channel", a="tensor", no_a_quant=True, enabled=False),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
    if w_type == "tensor":
        scale_b = torch.ones(1, device=device, dtype=torch.float32)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    else:
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
    return b_fp8.t(), scale_b_fp8


def build_fp8_runner(cfg, a, b, dtype, device):
    b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)

    scale_a_const = (
        torch.ones(1, device=device, dtype=torch.float32)
        if cfg["a"] == "tensor"
        else None
    )

    if cfg["no_a_quant"]:
        if cfg["a"] == "tensor":
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
        else:
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)

        def run():
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        return run

    if cfg["a"] == "tensor":

        def run():
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    else:

        def run():
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs FP8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_fp8_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_fp8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
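Note: the `to_tflops` lambda above is the standard GEMM throughput conversion — a dense M×K by K×N matmul costs 2*M*N*K floating-point operations (one multiply and one add per multiply-accumulate). A tiny worked example with made-up numbers:

# Hypothetical shape and timing, for illustration only.
M, N, K = 1024, 6144, 4096
t_ms = 0.085  # measured kernel time in milliseconds

flops = 2 * M * N * K                      # one mul + one add per MAC
tflops = flops * 1e-12 / (t_ms * 1e-3)     # FLOPs -> TFLOP/s
print(f"{tflops:.1f} TFLOP/s")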
benchmarks/kernels/benchmark_fused_collective.py  0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark for FlashInfer fused collective operations vs standard operations.

This benchmark compares:
1. FlashInfer's allreduce_fusion with trtllm backend
   (fused allreduce + rmsnorm + optional FP8/FP4 quant)
2. FlashInfer's allreduce_fusion with mnnvl backend
   (fused allreduce + rmsnorm only, no quantization support)
3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations

Usage with torchrun:
    torchrun --nproc_per_node=2 benchmark_fused_collective.py
"""

import argparse
import itertools
import os
import time

import pandas as pd
import torch  # type: ignore
import torch.distributed as dist  # type: ignore

from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (
    tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
    graph_capture,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm  # noqa
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8  # noqa
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape  # noqa
from vllm.platforms import current_platform  # noqa

RMS_NORM_OP = torch.ops._C.rms_norm
FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm
RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = (
    torch.ops._C.fused_add_rms_norm_static_fp8_quant
)
SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant

logger = init_logger(__name__)

# Try to import FlashInfer
TorchDistBackend = None
try:
    import flashinfer.comm as flashinfer_comm  # type: ignore
    from flashinfer.comm.mnnvl import (  # type: ignore
        TorchDistBackend,
    )

    if not (
        hasattr(flashinfer_comm, "allreduce_fusion")
        and hasattr(flashinfer_comm, "create_allreduce_fusion_workspace")
    ):
        flashinfer_comm = None
        logger.warning("FlashInfer comm module found but missing allreduce_fusion API")
except ImportError:
    flashinfer_comm = None
    logger.warning("FlashInfer not found, only benchmarking standard operations")

# Constants
FP8_DTYPE = current_platform.fp8_dtype()
MiB = 1024 * 1024

# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --disable-oneshot to disable oneshot mode for very large input sizes
_FI_MAX_SIZES = {
    2: 64 * MiB,  # 64MB
    4: 64 * MiB,  # 64MB
    8: 64 * MiB,  # 64MB
}

# Global workspace tensors for FlashInfer (keyed by backend name)
_FI_WORKSPACES: dict = {}

# Backends to benchmark
FLASHINFER_BACKENDS = ["trtllm", "mnnvl"]


def setup_flashinfer_workspace(
    backend: str,
    world_size: int,
    rank: int,
    hidden_dim: int,
    max_token_num: int,
    dtype: torch.dtype,
):
    """Setup FlashInfer workspace for fused allreduce operations."""
    global _FI_WORKSPACES

    if flashinfer_comm is None:
        return None

    if world_size not in _FI_MAX_SIZES:
        logger.warning("FlashInfer not supported for world size %s", world_size)
        return None

    try:
        kwargs = {}
        if TorchDistBackend is not None:
            kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD)
        workspace = flashinfer_comm.create_allreduce_fusion_workspace(
            backend=backend,
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            dtype=dtype,
            **kwargs,
        )
        _FI_WORKSPACES[backend] = workspace
        return workspace
    except Exception as e:
        logger.error(
            "Failed to setup FlashInfer workspace (backend=%s): %s", backend, e
        )
        return None


def cleanup_flashinfer_workspaces():
    """Cleanup all FlashInfer workspaces."""
    if flashinfer_comm is None:
        return
    for backend, workspace in _FI_WORKSPACES.items():
        try:
            workspace.destroy()
        except Exception as e:
            logger.error(
                "Failed to cleanup FlashInfer workspace (backend=%s): %s",
                backend,
                e,
            )
    _FI_WORKSPACES.clear()


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations."""

    def __init__(
        self,
        max_token_num: int = 1024,
    ):
        self.launch_with_pdl = True
        self.fp32_acc = True
        self.max_token_num = max_token_num

    def get_flashinfer_fused_allreduce_kwargs(self):
        return {
            "launch_with_pdl": self.launch_with_pdl,
            "fp32_acc": self.fp32_acc,
        }


def flashinfer_fused_allreduce_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    allreduce_params: "FlashInferFusedAllReduceParams",
    workspace: object,
    use_oneshot: bool,
    norm_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm operation."""
    if flashinfer_comm is None or workspace is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")
    if norm_out is None:
        norm_out = input_tensor
        residual_out = residual
    else:
        residual_out = input_tensor
    layout_code = None
    if workspace.backend == "trtllm":
        layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
    flashinfer_comm.allreduce_fusion(
        input=input_tensor,
        workspace=workspace,
        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        quant_out=None,
        scale_out=None,
        layout_code=layout_code,
        scale_factor=None,
        use_oneshot=use_oneshot,
        **allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
    )


def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    scale_factor: torch.Tensor,
    allreduce_params: FlashInferFusedAllReduceParams,
    workspace: object,
    use_oneshot: bool = True,
    norm_out: torch.Tensor | None = None,
    quant_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm + FP8 quantization.

    Note: Only supported by the trtllm backend.
    """
    if flashinfer_comm is None or workspace is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")
    if norm_out is None:
        norm_out = input_tensor
        residual_out = residual
    else:
        residual_out = input_tensor
    flashinfer_comm.allreduce_fusion(
        input=input_tensor,
        workspace=workspace,
        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        quant_out=quant_out,
        scale_out=None,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
        scale_factor=scale_factor,
        use_oneshot=use_oneshot,
        **allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
    )


def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    input_global_scale: torch.Tensor,
    allreduce_params: FlashInferFusedAllReduceParams,
    workspace: object,
    quant_out: torch.Tensor,
    use_oneshot: bool,
    output_scale: torch.Tensor,
    norm_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm + FP4 quantization.

    Note: Only supported by the trtllm backend.
    """
    if flashinfer_comm is None or workspace is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")
    if norm_out is None:
        norm_out = input_tensor
        residual_out = residual
    else:
        residual_out = input_tensor
    flashinfer_comm.allreduce_fusion(
        input=input_tensor,
        workspace=workspace,
        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        quant_out=quant_out,
        scale_out=output_scale,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
        scale_factor=input_global_scale,
        use_oneshot=use_oneshot,
        **allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
    )


class VllmFusedAllreduce:
    def __init__(self, hidden_dim, dtype):
        self.rms_eps = 1e-6
        self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype)
        self.fp8_quant = QuantFP8(
            static=True,
            group_shape=GroupShape.PER_TENSOR,
        )

    def allreduce_rmsnorm(
        self, input_tensor: torch.Tensor, residual: torch.Tensor | None
    ):
        allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
        return self.rms_norm(allreduce_out, residual)

    def allreduce_rmsnorm_fp8_quant(
        self,
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        scale_factor: torch.Tensor,
    ):
        allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
        rms_out = self.rms_norm(allreduce_out, residual)
        if residual is None:
            quant_out = self.fp8_quant(rms_out, scale_factor)
            return quant_out
        else:
            rms_out, residual_out = rms_out
            quant_out = self.fp8_quant(rms_out, scale_factor)
            return quant_out, residual_out

    def allreduce_rmsnorm_fp4_quant(
        self,
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        input_global_scale: torch.Tensor,
        quant_out: torch.Tensor,
        output_scale: torch.Tensor,
    ):
        allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
        rms_out = self.rms_norm(allreduce_out, residual)
        if residual is None:
            SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale)
            return quant_out, output_scale
        else:
            rms_out, residual_out = rms_out
            SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale)
            return quant_out, residual_out, output_scale


def create_test_tensors(
    num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True
):
    """Create test tensors for benchmarking."""
    input_tensor = torch.randn(num_tokens, hidden_dim, dtype=dtype)
    residual = (
        torch.randn_like(input_tensor)
        if use_residual
        else torch.zeros_like(input_tensor)
    )
    rms_gamma = torch.ones(hidden_dim, dtype=dtype)
    norm_out = None if use_residual else torch.empty_like(input_tensor)

    # Quantization scales
    scale_fp8 = torch.tensor(1.0, dtype=torch.float32)
    scale_fp4 = torch.tensor(1.0, dtype=torch.float32)
    quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE)

    # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
    fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8)
    fp4_output_scale = torch.empty((128, 4), dtype=torch.int32)

    return (
        input_tensor,
        norm_out,
        residual,
        rms_gamma,
        scale_fp8,
        quant_out_fp8,
        scale_fp4,
        fp4_quant_out,
        fp4_output_scale,
    )


def benchmark_operation(
    operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs
):
    """Benchmark a single operation using CUDA graphs."""
    # Warmup before graph capture
    for _ in range(warmup):
        operation_func(*args, **kwargs)
    torch.cuda.synchronize()

    # Create CUDA graph
    graph = torch.cuda.CUDAGraph()
    num_op_per_cudagraph = 10
    # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    with graph_capture(device=device), torch.cuda.graph(graph):
        for _ in range(num_op_per_cudagraph):
            operation_func(*args, **kwargs)

    # Graph warmup
    torch.cuda.synchronize()
    for _ in range(warmup):
        graph.replay()

    # Benchmark with CUDA graph
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(trials // num_op_per_cudagraph):
        # operation_func(*args, **kwargs)
        graph.replay()
    torch.cuda.synchronize()
    end_time = time.perf_counter()

    avg_time_ms = ((end_time - start_time) / trials) * 1000
    return avg_time_ms


def run_benchmarks(
    num_tokens: int,
    hidden_dim: int,
    dtype: torch.dtype,
    use_residual: bool,
    allreduce_params: FlashInferFusedAllReduceParams | None,
    workspaces: dict,
    quant_modes: set[str],
    no_oneshot: bool,
):
    """Run all benchmarks for given configuration.

    Args:
        allreduce_params: Shared parameters for FlashInfer fused allreduce.
        workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace.
        quant_modes: Set of quantization modes: "none", "fp8", "fp4".
    """
    (
        input_tensor,
        norm_out,
        residual,
        rms_gamma,
        scale_fp8,
        quant_out_fp8,
        scale_fp4,
        fp4_quant_out,
        fp4_output_scale,
    ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual)

    rms_eps = 1e-6
    results = {}
    use_oneshot_options = [False] if no_oneshot else [True, False]

    if "none" in quant_modes:
        # Standard AllReduce + RMSNorm
        # Re-create VllmFusedAllreduce per config so CustomOp binds the
        # correct forward method (native vs custom kernel).
        for custom_op in ["-rms_norm", "+rms_norm"]:
            with set_current_vllm_config(
                VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
            ):
                try:
                    vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                    suffix = (
                        "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
                    )
                    time_ms = benchmark_operation(
                        vllm_fused_allreduce.allreduce_rmsnorm,
                        input_tensor,
                        residual=residual,
                    )
                    results[f"standard_allreduce_{suffix}"] = time_ms
                except Exception as e:
                    logger.error("Standard AllReduce+RMSNorm failed: %s", e)
                    results[f"standard_allreduce_{suffix}"] = float("inf")

        # Standard AllReduce + RMSNorm Native Compiled
        with set_current_vllm_config(
            VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
        ):
            try:
                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                standard_allreduce_rmsnorm_native_compiled = torch.compile(
                    vllm_fused_allreduce.allreduce_rmsnorm,
                    fullgraph=True,
                    dynamic=False,
                )
                time_ms = benchmark_operation(
                    standard_allreduce_rmsnorm_native_compiled,
                    input_tensor,
                    residual=residual,
                )
                results["standard_allreduce_rmsnorm_native_compiled"] = time_ms
            except Exception as e:
                logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
                results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")

        # FlashInfer Fused AllReduce + RMSNorm (all backends)
        for backend, workspace in workspaces.items():
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}"
                try:
                    time_ms = benchmark_operation(
                        flashinfer_fused_allreduce_rmsnorm,
                        input_tensor,
                        residual=residual,
                        norm_out=norm_out,
                        rms_gamma=rms_gamma,
                        rms_eps=rms_eps,
                        allreduce_params=allreduce_params,
                        workspace=workspace,
                        use_oneshot=use_oneshot,
                    )
                    results[key] = time_ms
                except Exception as e:
                    logger.error(
                        "FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s",
                        backend,
                        e,
                    )
                    results[key] = float("inf")

    if "fp8" in quant_modes:
        # Standard AllReduce + RMSNorm + FP8 Quant
        for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
            suffix = (
                "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
            )
            for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
                op_suffix = suffix + (
                    "_custom_quant_fp8"
                    if "+" in quant_fp8_custom_op
                    else "_native_quant_fp8"
                )
                with set_current_vllm_config(
                    VllmConfig(
                        compilation_config=CompilationConfig(
                            custom_ops=[rms_norm_custom_op, quant_fp8_custom_op]
                        )
                    )
                ):
                    try:
                        vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                        time_ms = benchmark_operation(
                            vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
                            input_tensor,
                            residual=residual,
                            scale_factor=scale_fp8,
                        )
                        results[f"standard_allreduce{op_suffix}"] = time_ms
                    except Exception as e:
                        logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
                        results[f"standard_allreduce{op_suffix}"] = float("inf")

        # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
        with set_current_vllm_config(
            VllmConfig(
                compilation_config=CompilationConfig(
                    custom_ops=["-rms_norm", "-quant_fp8"]
                )
            )
        ):
            try:
                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
                    vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
                    fullgraph=True,
                    dynamic=False,
                )
                time_ms = benchmark_operation(
                    standard_allreduce_rmsnorm_fp8_quant_native_compiled,
                    input_tensor,
                    residual=residual,
                    scale_factor=scale_fp8,
                )
                results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = (
                    time_ms
                )
            except Exception as e:
                logger.error(
                    "Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e
                )
                results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float(
                    "inf"
                )

        # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only)
        if "trtllm" in workspaces:
            trtllm_ws = workspaces["trtllm"]
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}"
                try:
                    time_ms = benchmark_operation(
                        flashinfer_fused_allreduce_rmsnorm_fp8_quant,
                        input_tensor,
                        norm_out=norm_out,
                        residual=residual,
                        rms_gamma=rms_gamma,
                        rms_eps=rms_eps,
                        scale_factor=scale_fp8,
                        quant_out=quant_out_fp8,
                        allreduce_params=allreduce_params,
                        workspace=trtllm_ws,
                        use_oneshot=use_oneshot,
                    )
                    results[key] = time_ms
                except Exception as e:
                    logger.error(
                        "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s",
                        e,
                    )
                    results[key] = float("inf")

    if "fp4" in quant_modes and current_platform.has_device_capability(100):
        # Standard AllReduce + RMSNorm + FP4 Quant
        for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
            suffix = (
                "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
            )
            with set_current_vllm_config(
                VllmConfig(
                    compilation_config=CompilationConfig(
                        custom_ops=[rms_norm_custom_op]
                    )
                )
            ):
                try:
                    vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                    time_ms = benchmark_operation(
                        vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
                        input_tensor,
                        residual=residual,
                        input_global_scale=scale_fp4,
                        quant_out=fp4_quant_out,
                        output_scale=fp4_output_scale,
                    )
                    results[f"standard_allreduce_{suffix}_fp4_quant"] = time_ms
                except Exception as e:
                    logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e)
                    results[f"standard_allreduce_{suffix}_fp4_quant"] = float("inf")

        # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
        with set_current_vllm_config(
            VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
        ):
            try:
                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
                    vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
                    fullgraph=True,
                    dynamic=False,
                )
                time_ms = benchmark_operation(
                    standard_allreduce_rmsnorm_fp4_quant_native_compiled,
                    input_tensor,
                    residual=residual,
                    quant_out=fp4_quant_out,
                    input_global_scale=scale_fp4,
                    output_scale=fp4_output_scale,
                )
                results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = (
                    time_ms
                )
            except Exception as e:
                logger.error(
                    "Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e
                )
                results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float(
                    "inf"
                )

        # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only)
        if "trtllm" in workspaces:
            trtllm_ws = workspaces["trtllm"]
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}"
                try:
                    time_ms = benchmark_operation(
                        flashinfer_fused_allreduce_rmsnorm_fp4_quant,
                        input_tensor,
                        residual=residual,
                        norm_out=norm_out,
                        rms_gamma=rms_gamma,
                        rms_eps=rms_eps,
                        input_global_scale=scale_fp4,
                        allreduce_params=allreduce_params,
                        workspace=trtllm_ws,
                        quant_out=fp4_quant_out,
                        output_scale=fp4_output_scale,
                        use_oneshot=use_oneshot,
                    )
                    results[key] = time_ms
                except Exception as e:
                    logger.error(
                        "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s",
                        e,
                    )
                    results[key] = float("inf")

    return results


def prepare_results_with_speedups(results_dict):
    """Prepare results with speedup calculations based on dynamic baseline selection."""
    prepared_results = []

    # Determine the fastest baseline for each operation type
    def get_fastest_baseline(op_name, results_dict):
        """Get the fastest baseline between standard and native_compiled versions."""
        if "fp8_quant" in op_name:
            candidates = [
                "standard_allreduce_rmsnorm_fp8_quant",
                "standard_allreduce_rmsnorm_fp8_quant_native_compiled",
            ]
        elif "fp4_quant" in op_name:
            candidates = [
                "standard_allreduce_rmsnorm_fp4_quant",
                "standard_allreduce_rmsnorm_fp4_quant_native_compiled",
            ]
        else:
            candidates = [
                "standard_allreduce_rmsnorm",
                "standard_allreduce_rmsnorm_native_compiled",
            ]

        # Find the fastest among available candidates
        fastest_time = float("inf")
        fastest_baseline = None
        for candidate in candidates:
            if (
                candidate in results_dict
                and results_dict[candidate] != float("inf")
                and results_dict[candidate] < fastest_time
            ):
                fastest_time = results_dict[candidate]
                fastest_baseline = candidate
        return fastest_baseline

    # Create dynamic baseline mapping
    dynamic_baseline_mapping = {}
    for op_name in results_dict:
        if (
            op_name.startswith("flashinfer_")
            or op_name.startswith("standard_")
            and not op_name.endswith("_native_compiled")
        ):
            dynamic_baseline_mapping[op_name] = get_fastest_baseline(
                op_name, results_dict
            )

    for op_name, time_ms in results_dict.items():
        if time_ms == float("inf"):
            speedup_str = "FAILED"
            time_str = "FAILED"
        else:
            time_str = f"{time_ms:.3f}"
            # Find the appropriate baseline for this operation
            baseline_op = dynamic_baseline_mapping.get(op_name)
            if baseline_op and baseline_op in results_dict:
                baseline_time = results_dict[baseline_op]
                if baseline_time != float("inf") and baseline_time > 0:
                    speedup = baseline_time / time_ms
                    speedup_str = f"{speedup:.2f}x"
                else:
                    speedup_str = "N/A"
            else:
                # For baseline operations, determine if this is the fastest baseline
                if op_name.endswith("_native_compiled") or (
                    op_name.startswith("standard_")
                    and not op_name.endswith("_native_compiled")
                ):
                    fastest_baseline = get_fastest_baseline(op_name, results_dict)
                    if fastest_baseline == op_name:
                        speedup_str = "baseline"
                    else:
                        if fastest_baseline and fastest_baseline in results_dict:
                            baseline_time = results_dict[fastest_baseline]
                            if baseline_time != float("inf") and baseline_time > 0:
                                speedup = baseline_time / time_ms
                                speedup_str = f"{speedup:.2f}x"
                            else:
                                speedup_str = "N/A"
                        else:
                            speedup_str = "N/A"
                else:
                    speedup_str = "N/A"

        prepared_results.append(
            {
                "operation": op_name,
                "time_ms": time_ms,
                "time_str": time_str,
                "speedup_str": speedup_str,
            }
        )

    return prepared_results


def print_results(
    results_dict,
    num_tokens,
    hidden_dim,
    dtype,
    use_residual,
    quant_modes,
    input_size_mb,
):
    """Print benchmark results in a formatted table."""
    print(f"\n{'=' * 80}")
    print(
        f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} "
        f"(input size: {input_size_mb:.2f} MB)"
    )
    print(
        f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, "
        f"quant_modes={','.join(sorted(list(quant_modes)))}"
    )
    print(f"{'=' * 80}")
    print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}")
    print(f"{'-' * 80}")

    # Prepare results with speedup calculations
    prepared_results = prepare_results_with_speedups(results_dict)

    for result in prepared_results:
        if result["time_ms"] == float("inf"):
            time_display = result["time_str"]
        else:
            time_display = f"{result['time_ms']:.3f}"
        print(
            f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}"
        )


def format_results_markdown(
    all_results: list[dict], world_size: int, args: argparse.Namespace
) -> str:
    """Format all benchmark results as markdown."""
    lines: list[str] = []
    lines.append("# FlashInfer Fused Collective Operations Benchmark Results")
    lines.append("")
    lines.append(f"**World Size:** {world_size}")
    lines.append(f"**Hidden Dimension:** {args.hidden_dim}")
    lines.append(f"**Warmup Iterations:** {args.warmup}")
    lines.append(f"**Benchmark Trials:** {args.trials}")
    modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A"
    lines.append(f"**Quantization Modes:** {modes}")
    lines.append("")
    lines.append("---")
    lines.append("")

    for entry in all_results:
        num_tokens = entry["num_tokens"]
        dtype = entry["dtype"]
        use_residual = entry["use_residual"]
        results_dict = entry["results"]
        input_size_mb = entry["input_size_mb"]

        residual_str = "with residual" if use_residual else "no residual"
        lines.append(
            f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}"
        )
        lines.append(f"**Input Size:** {input_size_mb:.2f} MB")
        lines.append("")

        prepared = prepare_results_with_speedups(results_dict)

        # Build DataFrame for markdown export
        rows = [
            {
                "Operation": r["operation"].replace("_", " ").title(),
                "Time (ms)": r["time_str"],
                "Speedup": r["speedup_str"],
            }
            for r in prepared
        ]
        df = pd.DataFrame(rows)
        if df.empty:
            lines.append("No results.")
        else:
            lines.append(df.to_markdown(index=False))
        lines.append("")

    return "\n".join(lines)


def save_results_to_file(
    all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int
):
    """Save benchmark results to markdown file (only on rank 0)."""
    if rank != 0:
        return

    if not all_results:
        logger.warning("No results to save")
        return

    output_path = args.output_file
    try:
        markdown_content = format_results_markdown(all_results, world_size, args)
        with open(output_path, "a") as f:
            f.write(markdown_content)
    except Exception as e:
        logger.error("Failed to save results to file: %s", e)


def main():
    parser = argparse.ArgumentParser(
        description="Benchmark fused collective operations"
    )
    parser.add_argument(
        "--num-tokens",
        type=int,
        nargs="+",
        default=[128, 512, 1024, 2048],
        help="Numbers of tokens to test",
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=8192, help="Hidden dimension size"
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        default=["bfloat16"],
        choices=["float16", "bfloat16", "float32"],
        help="Data types to test",
    )
    parser.add_argument(
        "--no-residual",
        action="store_true",
        help="Skip residual connection tests",
    )
    parser.add_argument(
        "--quant-modes",
        type=str,
        default="none,fp8,fp4",
        help=(
            "Comma-separated quantization modes to run: none, fp8, fp4. "
            "Default: none,fp8,fp4"
        ),
    )
    parser.add_argument(
        "--warmup", type=int, default=5, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--trials", type=int, default=20, help="Number of benchmark trials"
    )
    parser.add_argument(
        "--output-file",
        type=str,
        help="""Output file path for markdown results
        (default: benchmark_results_<timestamp>.md)
        """,
    )
    parser.add_argument(
        "--no-oneshot",
        action="store_true",
        help="Skip oneshot benchmarks",
    )
    args = parser.parse_args()

    # Check if running with torchrun (required for collective operations)
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        raise RuntimeError(
            "Must run with torchrun for distributed benchmarking. "
            "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
        )

    # Initialize distributed environment
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Validate world size (must be > 1 for collective operations)
    if world_size <= 1:
        raise ValueError(
            "World size must be > 1 for collective operations benchmarking. "
            f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1."
        )

    # Parse quantization modes
    valid_quant_modes = {"none", "fp8", "fp4"}
    raw_modes = [
        m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip()
    ]
    quant_modes = set(raw_modes) if raw_modes else {"none", "fp8", "fp4"}
    invalid = sorted(list(quant_modes - valid_quant_modes))
    if invalid:
        raise ValueError(
            f"Invalid --quant-modes entries: {','.join(invalid)}. "
            f"Valid options are: {','.join(sorted(valid_quant_modes))}."
        )

    if rank == 0:
        logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank)
        logger.info("Quantization modes: %s", ",".join(sorted(list(quant_modes))))
        if flashinfer_comm is not None:
            logger.info(
                "FlashInfer available - will benchmark fused operations",
            )
        else:
            logger.info(
                "FlashInfer not available - only benchmarking standard operations"
            )

    # Convert dtype strings to torch dtypes
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtypes = [dtype_map[dt] for dt in args.dtypes]

    # Test configurations
    residual_options = [True] if not args.no_residual else [False]
    configs = list(itertools.product(args.num_tokens, dtypes, residual_options))

    # Setup FlashInfer workspaces for all backends
    allreduce_params = None
    if flashinfer_comm is not None:
        # Use the largest hidden dimension for workspace setup
        max_element_size = max(torch.finfo(dt).bits // 8 for dt in dtypes)
        workspace_dtype = (
            torch.float32
            if max_element_size == 4
            else (torch.bfloat16 if torch.bfloat16 in dtypes else torch.float16)
        )
        max_num_token = _FI_MAX_SIZES.get(world_size) // (
            args.hidden_dim * max_element_size
        )
        for backend in FLASHINFER_BACKENDS:
            setup_flashinfer_workspace(
                backend=backend,
                world_size=world_size,
                rank=rank,
                hidden_dim=args.hidden_dim,
                max_token_num=max_num_token,
                dtype=workspace_dtype,
            )
        if _FI_WORKSPACES:
            allreduce_params = FlashInferFusedAllReduceParams(
                max_token_num=max_num_token,
            )

    # Collect all results for markdown export
    all_results = []

    try:
        # Run benchmarks
        for num_tokens, dtype, use_residual in configs:
            if rank == 0:
                logger.info(
                    "\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s",
                    num_tokens,
                    args.hidden_dim,
                    dtype,
                    use_residual,
                )

            results = run_benchmarks(
                num_tokens,
                args.hidden_dim,
                dtype,
                use_residual,
                allreduce_params,
                workspaces=_FI_WORKSPACES,
                quant_modes=quant_modes,
                no_oneshot=args.no_oneshot,
            )

            # Store results for markdown export
            if rank == 0:
                # Calculate input size in MB
                input_size_mb = (
                    num_tokens * args.hidden_dim * torch.finfo(dtype).bits
                ) / (8 * 1024 * 1024)
                all_results.append(
                    {
                        "num_tokens": num_tokens,
                        "hidden_dim": args.hidden_dim,
                        "dtype": str(dtype).replace("torch.", ""),
                        "use_residual": use_residual,
                        "quant_modes": sorted(list(quant_modes)),
                        "input_size_mb": input_size_mb,
                        "results": results,
                    }
                )

                print_results(
                    results,
                    num_tokens,
                    args.hidden_dim,
                    dtype,
                    use_residual,
                    quant_modes,
                    input_size_mb,
                )

        # Save results to markdown file
        if args.output_file and rank == 0:
            save_results_to_file(all_results, world_size, args, rank)

    finally:
        # Cleanup
        cleanup_flashinfer_workspaces()
        dist.barrier()


if __name__ == "__main__":
    from vllm.config import VllmConfig, set_current_vllm_config

    with set_current_vllm_config(VllmConfig()):
        main()
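Note on the timing scheme: `benchmark_operation` above amortizes per-launch overhead by baking ten invocations into a single CUDA graph and dividing wall time by the total op count. A CPU-only sketch of the same amortization idea, with `dummy_op` as a stand-in for a fused kernel (not part of the benchmark):

import time

def dummy_op():
    sum(range(1000))  # stand-in for one fused kernel launch

num_op_per_batch = 10  # mirrors num_op_per_cudagraph
trials = 200

start = time.perf_counter()
for _ in range(trials // num_op_per_batch):
    for _ in range(num_op_per_batch):  # one "graph replay" of 10 ops
        dummy_op()
elapsed = time.perf_counter() - start

# Dividing by `trials` gives time per single op, with launch cost amortized.
print(f"avg per-op time: {elapsed / trials * 1000:.4f} ms")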
benchmarks/kernels/benchmark_fused_topk.py  0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools

import torch

from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

num_tokens_range = [2**i for i in range(0, 8, 2)]
num_experts_range = [16, 32, 64, 128, 256, 512]
topk_range = [3, 4]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))


def torch_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",
):
    if scoring_func == "softmax":
        scores = torch.softmax(gating_output.float(), dim=-1)
    else:
        scores = torch.sigmoid(gating_output.float())
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def get_benchmark(scoring_func):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens", "num_experts", "topk"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["torch", "vllm"],
            line_names=["Torch", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"fused-topk-perf-{scoring_func}",
            args={},
        )
    )
    def benchmark(num_tokens, num_experts, topk, provider):
        dtype = torch.bfloat16
        hidden_size = 1024
        renormalize = True

        hidden_states = torch.randn(
            (num_tokens, hidden_size), dtype=dtype, device="cuda"
        )
        gating_output = torch.randn(
            (num_tokens, num_experts), dtype=dtype, device="cuda"
        )

        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: torch_topk(
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fused_topk(
                    hidden_states=hidden_states,
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                ),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the MoE topk kernel.")
    parser.add_argument("--scoring-func", type=str, default="softmax")
    parser.add_argument("--save-path", type=str, default="./configs/fused_topk/")
    args = parser.parse_args()

    # Get the benchmark function
    benchmark = get_benchmark(args.scoring_func)
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
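As a quick sanity check of the reference path: `torch_topk` with `renormalize=True` rescales the selected softmax weights so each row sums to one. A small self-contained example (shapes are arbitrary):

import torch

gating = torch.randn(2, 8)                      # 2 tokens, 8 experts
scores = torch.softmax(gating.float(), dim=-1)
weights, ids = torch.topk(scores, k=3, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)

print(ids)                  # chosen expert indices per token
print(weights.sum(dim=-1))  # tensor([1., 1.]) after renormalization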
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py  0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_experts,
    fused_topk,
)
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager

DEFAULT_MODELS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "deepseek-ai/DeepSeek-V2-Lite",
    "ibm-granite/granite-3.0-1b-a400m",
    "ibm-granite/granite-3.0-3b-a800m",
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )


def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    init_workspace_manager(torch.cuda.current_device())

    label = "Quant Matmul"
    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )
    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty(
        (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        a, score, topk, renormalize=False
    )

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
        num_repeats: int,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def run_cutlass_moe(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        per_act_token: bool,
        num_repeats: int,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )
        moe_config = make_dummy_moe_config(
            num_experts=w2.shape[0],
            hidden_dim=w2.shape[1],
            intermediate_size_per_partition=w2.shape[2],
            in_dtype=a.dtype,
        )
        fn = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            CutlassExpertsFp8(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
        )
        for _ in range(num_repeats):
            fn(a, w1, w2, topk_weights, topk_ids)

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )
        moe_config = make_dummy_moe_config(
            num_experts=w2.shape[0],
            hidden_dim=w2.shape[1],
            intermediate_size_per_partition=w2.shape[2],
            in_dtype=a.dtype,
        )
        fn = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            CutlassExpertsFp8(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
        )
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fn(a, w1, w2, topk_weights, topk_ids)

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a,
            a_scale,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
        )
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_q,
            w2_q,
            topk_weights,
            topk_ids,
            w1_scale,
            w2_scale,
            a_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "per_act_token": per_act_token,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        w1_scale,
        w2_scale,
        a_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe(
        a,
        a_scale,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        per_act_token,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )


def main(args):
    # Initialize workspace manager (required for CUTLASS MoE kernels)
    device = torch.device("cuda:0")
    init_workspace_manager(device)

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in DEFAULT_BATCH_SIZES:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
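Note: the per-expert quantization loop above produces one scalar scale per expert, which is why `w1_scale`/`w2_scale` are shaped `(num_experts, 1, 1)`. The `to_fp8` helper uses the same clamp-and-cast pattern shown in this small sketch (values chosen arbitrarily):

import torch

finfo = torch.finfo(torch.float8_e4m3fn)
x = torch.tensor([1.5, 500.0, -500.0])  # 500 exceeds the e4m3 range
x_fp8 = torch.round(x.clamp(min=finfo.min, max=finfo.max)).to(torch.float8_e4m3fn)

print(finfo.max)        # 448.0 for float8_e4m3fn
print(x_fp8.float())    # out-of-range values saturate to +/-448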
benchmarks/kernels/benchmark_int8_gemm.py  0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "int8-tensor-w-token-a": dict(w="tensor", a="token", no_a_quant=False, enabled=False),
    "int8-tensor-w-tensor-a": dict(w="tensor", a="tensor", no_a_quant=False, enabled=True),
    "int8-channel-w-token-a": dict(w="channel", a="token", no_a_quant=False, enabled=True),
    "int8-channel-w-tensor-a": dict(w="channel", a="tensor", no_a_quant=False, enabled=False),
    "int8-tensor-w-token-a-noquant": dict(w="tensor", a="token", no_a_quant=True, enabled=False),
    "int8-tensor-w-tensor-a-noquant": dict(w="tensor", a="tensor", no_a_quant=True, enabled=True),
    "int8-channel-w-token-a-noquant": dict(w="channel", a="token", no_a_quant=True, enabled=True),
    "int8-channel-w-tensor-a-noquant": dict(w="channel", a="tensor", no_a_quant=True, enabled=False),
}


def _quant_weight(b, w_type, device):
    if w_type == "tensor":
        scale_b = torch.ones(1, device=device, dtype=torch.float32)
        b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
        assert scale_b_int8.numel() == 1
    else:  # channel
        b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
        assert scale_b_int8.numel() == b.shape[0]
    return b_int8.t(), scale_b_int8


def build_int8_runner(cfg, a, b, dtype, device):
    # quant before running the kernel
    b_int8, scale_b_int8 = _quant_weight(b, cfg["w"], device)

    scale_a_const = None
    if cfg["a"] == "tensor":
        scale_a_const = torch.ones(1, device=device, dtype=torch.float32)

    # no quant, create activation ahead
    if cfg["no_a_quant"]:
        if cfg["a"] == "tensor":
            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const)
        else:  # token
            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)

        def run_quant():
            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)

        return run_quant

    # dynamic quant, create activation inside
    if cfg["a"] == "tensor":

        def run_quant():
            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const)
            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)

    else:  # token

        def run_quant():
            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)

    return run_quant


_enabled = [k for k, v in PROVIDER_CFGS.items() if v.get("enabled")]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=[k for k in _enabled],
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs INT8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_int8_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    KN_model_names = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_int8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
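Note: the asserts in `_quant_weight` encode the scale-shape contract — per-tensor quantization yields a single scale, per-channel yields one scale per output row. A pure-PyTorch analogue of those shapes (this mirrors the contract only; the actual rounding is done by vLLM's `scaled_int8_quant` kernel):

import torch

N, K = 4, 8
b = torch.randn(N, K)

# Per-tensor: one scale for the whole weight matrix -> numel() == 1.
scale_tensor = b.abs().max() / 127.0

# Per-channel: one scale per output channel (row of b) -> numel() == N.
scale_channel = b.abs().amax(dim=1) / 127.0

b_int8 = torch.clamp(
    torch.round(b / scale_channel[:, None]), -128, 127
).to(torch.int8)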
benchmarks/kernels/benchmark_layernorm.py  0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time

import torch

from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed


@torch.inference_mode()
@default_vllm_config()
def main(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
    seed: int = 0,
    do_profile: bool = False,
    num_warmup_iters: int = 5,
    num_iters: int = 100,
) -> None:
    set_random_seed(seed)
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            layer(x, residual)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--add-residual", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations. "
        "If --profile is set, this number is ignored",
    )
    args = parser.parse_args()
    print(args)

    main(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        add_residual=args.add_residual,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
    )
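
# Example invocation (a sketch with illustrative values; the flags are the
# ones defined by the parser above):
#
#   python3 benchmarks/kernels/benchmark_layernorm.py --num-tokens 4096 \
#       --hidden-size 8192 --dtype bfloat16 --add-residual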
benchmarks/kernels/benchmark_lora.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import json
import pickle
import time
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum, auto
from itertools import product
from pathlib import Path
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES

from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.triton_utils import HAS_TRITON, triton

if HAS_TRITON:
    from vllm.lora.ops.triton_ops import (
        ## added fused_moe_lora
        LoRAKernelMeta,
        fused_moe_lora_expand,
        fused_moe_lora_shrink,
        lora_expand,
        lora_shrink,
    )
    from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
        _LORA_PTR_DICT,  ## added _LORA_PTR_DICT for fused_moe_lora
    )
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

from vllm import _custom_ops as ops
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.math_utils import round_up

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
    1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896,
    1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192,
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]
DEFAULT_TOP_K_NUMS = [1]  # Added for MoE LoRA top_k
DEFAULT_NUM_EXPERTS = [8]  # Added for MoE LoRA num_experts


# Utilities
def dtype_to_str(dtype: torch.dtype):
    if dtype == torch.float16:
        return "f16"
    if dtype == torch.bfloat16:
        return "bf16"
    if dtype == torch.float32:
        return "f32"
    raise ValueError(f"Unsupported dtype {dtype}")


def make_rand_lora_weight_tensor(
    k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda"
) -> torch.Tensor:
    # LoRA weights column major
    return torch.rand((num_loras, n, k), dtype=dtype).to(device)


def make_rand_tensors(
    a_shape: tuple[int, ...],
    b_shape: tuple[int, ...],
    c_shape: tuple[int, ...],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
    """
    Make LoRA input/output matrices.
    """
    A = torch.rand(a_shape, dtype=a_dtype).to(device)

    # LoRA weights column major
    Bs = [torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices)]

    C = torch.zeros(c_shape, dtype=c_dtype).to(device)
    return A, Bs, C


def make_prompt_lora_mapping(
    num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str
) -> torch.Tensor:
    """
    All prompts are mapped to a LoRA ID in range [0, num_active_loras),
    where 0 refers to the first LoRA, 1 refers to the second LoRA and so on.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        return torch.randint(0, num_active_loras, (num_prompts,), dtype=torch.long)

    # Divide LoRAs equally and in order.
    part_size = num_prompts // num_active_loras
    part_size = max(part_size, 1)

    lora_id = 0
    prompt_lora_mapping = []
    while len(prompt_lora_mapping) < num_prompts:
        prompt_lora_mapping.extend([lora_id] * part_size)
        lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id
    return torch.tensor(
        prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device
    )
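
# Worked example (illustrative numbers): num_prompts=5, num_active_loras=2 and
# sort_by_lora_id=True gives part_size = max(5 // 2, 1) = 2; the loop builds
# [0, 0, 1, 1, 1, 1] and keeps the first 5 entries: [0, 0, 1, 1, 1]. The last
# LoRA id absorbs the remainder because lora_id stops incrementing once it
# reaches num_active_loras - 1.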

def make_token_lora_mapping(
    num_tokens: int,
    num_prompts: int,
    prompt_lora_mapping: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    device: str,
):
    """
    Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor
    """
    assert prompt_lora_mapping.shape[0] == num_prompts

    # token to lora index mapping
    token_lora_mapping = [0] * num_tokens
    current_offset = 0
    for b_id in range(num_prompts):
        lora_index = prompt_lora_mapping[b_id].item()
        s = current_offset
        e = s + seq_len_tensor[b_id].item()
        token_lora_mapping[s:e] = [lora_index] * (e - s)
        current_offset += seq_len_tensor[b_id].item()

    return torch.tensor(token_lora_mapping, dtype=torch.long, device=device)
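
# Worked example (illustrative numbers): prompt_lora_mapping=[0, 1] with
# seq_len_tensor=[3, 2] expands to token_lora_mapping=[0, 0, 0, 1, 1];
# each prompt's LoRA id is repeated once per token of that prompt.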

def ref_group_gemm(
    ref_out: torch.Tensor,
    input: torch.Tensor,
    lora_weights: list[torch.Tensor],
    seq_lens_cpu: torch.Tensor,
    prompt_lora_mapping_cpu: torch.Tensor,
    scaling: float,
    add_inputs: bool | None,
):
    """
    Torch group gemm reference implementation to test correctness of
    benchmarking operations.
    """
    batches = seq_lens_cpu.size(0)
    out_list = []
    current_offset = 0
    for lora_index, b_length in zip(range(batches), seq_lens_cpu):
        x = input[current_offset : b_length + current_offset, :]
        current_offset += b_length
        w = lora_weights[prompt_lora_mapping_cpu[lora_index]]
        result = torch.nn.functional.linear(x, w)
        result *= scaling
        out_list.append(result)

    cat_result = torch.cat(out_list, dim=0)

    if add_inputs:
        ref_out += cat_result
    else:
        ref_out.copy_(cat_result)


class OpType(Enum):
    """
    LoRA Ops to benchmark and their properties.
    """

    LORA_SHRINK = auto()
    LORA_EXPAND = auto()
    ## Adding support for fused moe lora
    FUSED_MOE_LORA_GATE_UP_SHRINK = auto()  ## Gate/Up projection variant with shrink
    FUSED_MOE_LORA_GATE_UP_EXPAND = auto()  ## Gate/Up projection variant with expand
    FUSED_MOE_LORA_DOWN_SHRINK = auto()  ## Down projection variant with shrink
    FUSED_MOE_LORA_DOWN_EXPAND = auto()  ## Down projection variant with expand

    @staticmethod
    def from_str(s: str) -> "OpType":
        if s.lower() == "lora_shrink":
            return OpType.LORA_SHRINK
        if s.lower() == "lora_expand":
            return OpType.LORA_EXPAND
        # Adding support for fused moe lora, both in gate_up and down
        if s.lower() == "fused_moe_lora_gate_up_shrink":  ## Gate/Up variant with shrink
            return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK
        if s.lower() == "fused_moe_lora_gate_up_expand":  ## Gate/Up variant with expand
            return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND
        if s.lower() == "fused_moe_lora_down_shrink":  ## Down variant with shrink
            return OpType.FUSED_MOE_LORA_DOWN_SHRINK
        if s.lower() == "fused_moe_lora_down_expand":  ## Down variant with expand
            return OpType.FUSED_MOE_LORA_DOWN_EXPAND
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        return self in [OpType.LORA_SHRINK]

    def is_expand_fn(self) -> bool:
        return self in [OpType.LORA_EXPAND]

    def is_fused_moe_lora_fn(self) -> bool:  ## adding for fused MoE LoRA
        return self in [
            OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
            OpType.FUSED_MOE_LORA_DOWN_SHRINK,
            OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
            OpType.FUSED_MOE_LORA_DOWN_EXPAND,
        ]

    def is_fused_moe_lora_gate_up_fn(self) -> bool:
        ## adding for fused MoE LoRA Gate/Up
        return self in [
            OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
            OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
        ]

    def is_fused_moe_lora_down_fn(self) -> bool:  ## adding for fused MoE LoRA Down
        return self in [
            OpType.FUSED_MOE_LORA_DOWN_SHRINK,
            OpType.FUSED_MOE_LORA_DOWN_EXPAND,
        ]

    def is_fused_moe_lora_shrink_fn(self) -> bool:
        return self in [
            OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
            OpType.FUSED_MOE_LORA_DOWN_SHRINK,
        ]

    def is_fused_moe_lora_expand_fn(self) -> bool:
        return self in [
            OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
            OpType.FUSED_MOE_LORA_DOWN_EXPAND,
        ]

    def num_slices(self) -> list[int]:
        if self.is_fused_moe_lora_gate_up_fn():
            return [2]
        elif self.is_fused_moe_lora_down_fn():
            return [1]
        return [1, 2, 3]

    def mkn(
        self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
    ) -> tuple[int, int, int]:
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn() or self.is_fused_moe_lora_fn():
            m = num_tokens
            k = hidden_size
            n = lora_rank
        elif self.is_expand_fn():
            m = num_tokens
            k = lora_rank
            n = hidden_size
        return m, k, n

    def matmul_dtypes(
        self, op_dtype: torch.dtype
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C
        """
        if self.is_shrink_fn():
            return op_dtype, op_dtype, torch.float32
        elif self.is_expand_fn():
            return torch.float32, op_dtype, op_dtype
        else:
            assert self.is_fused_moe_lora_fn()
            return op_dtype, op_dtype, op_dtype

    def matmul_shapes_fused_moe_lora(
        self,
        m: int,
        n: int,
        k: int,
        num_loras: int,
        num_slices: int,
        top_k_num: int,
        num_experts: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        if self.is_fused_moe_lora_shrink_fn():
            input_shape = (
                (m * top_k_num, n)
                if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
                else (m, n)
            )
            output_shape = (num_slices, m, top_k_num, k)
            weight_shape = (num_loras, num_experts, k, n)
        else:
            assert self.is_fused_moe_lora_expand_fn()
            input_shape = (num_slices, m, top_k_num, k)
            output_shape = (m, top_k_num, n * num_slices)
            weight_shape = (num_loras, num_experts, n, k)
        return (input_shape, weight_shape, output_shape)

    def matmul_shapes(
        self,
        batch_size: int,
        seq_length: int,
        hidden_size: int,
        lora_rank: int,
        num_loras: int,
        num_slices: int,
        top_k_num: int | None = None,
        num_experts: int | None = None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
        b_shape = (num_loras, n, k)  # col-major
        if self in [OpType.LORA_SHRINK]:
            # LoRA shrink kernels support num_slices inherently in the kernel.
            return ((m, k), b_shape, (num_slices, m, n))
        if self in [OpType.LORA_EXPAND]:
            # LoRA expand kernels support num_slices inherently in the kernel
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        if self.is_fused_moe_lora_fn():
            return self.matmul_shapes_fused_moe_lora(
                m,
                k,
                n,
                num_loras,
                num_slices,
                top_k_num,
                num_experts,
            )
        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        if self == OpType.LORA_SHRINK:
            return lora_shrink
        if self == OpType.LORA_EXPAND:
            return lora_expand
        if self in [
            OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
            OpType.FUSED_MOE_LORA_DOWN_SHRINK,
        ]:
            return fused_moe_lora_shrink
        if self in [
            OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
            OpType.FUSED_MOE_LORA_DOWN_EXPAND,
        ]:
            return fused_moe_lora_expand
        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(
        self,
        output: torch.Tensor,
        input: torch.Tensor,
        lora_weights: list[torch.Tensor],
        **kwargs,
    ) -> Callable:
        """Each benchmark operation expects the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self in [OpType.LORA_SHRINK]:
            for slice_idx in range(num_slices):
                ref_group_gemm(
                    ref_out=output[slice_idx, :],
                    input=input,
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        elif self in [OpType.LORA_EXPAND]:
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset : slice_offset + hidden_size],
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        else:
            raise ValueError(f"Unrecognized optype {self}")
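
# Shape example (illustrative numbers): for OpType.LORA_SHRINK with
# batch_size=8, seq_length=1, hidden_size=4096, lora_rank=16, num_loras=4 and
# num_slices=2, matmul_shapes() yields A=(8, 4096), B=(4, 16, 4096)
# (col-major) and C=(2, 8, 16); the shrink op projects hidden_size down to
# lora_rank once per slice.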

@dataclass
class BenchmarkContext:
    """
    LoRA benchmark context
    """

    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
    seq_length: int | None = None
    num_experts: int | None = None  # num_experts for MoE based ops
    top_k_num: int | None = None  # top_k for MoE based ops
    num_slices: int | None = None  # num_slices for slice based ops

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        ctx = copy.copy(self)
        ctx.seq_length = seq_length
        return ctx

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        ctx = copy.copy(self)
        ctx.num_slices = num_slices
        return ctx

    def bench_label(self) -> str:
        return f"lora-{self.dtype}"

    def bench_sublabel(self, op_type: OpType) -> str:
        m, k, n = op_type.mkn(
            self.batch_size, self.seq_length, self.hidden_size, self.lora_rank
        )
        desc = {
            "bs": self.batch_size,
            "sl": self.seq_length,
            "m": m,
            "k": k,
            "n": n,
            "num_loras": self.num_loras,
            "sort_by_lora": self.sort_by_lora_id,
            "num_slices": self.num_slices,
        }
        return json.dumps(desc)


@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks
    """

    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: list[torch.Tensor]
    output: torch.Tensor
    # LoRA kernel metadata
    lora_kernel_meta: LoRAKernelMeta
    # Metadata tensors used in testing correctness
    seq_lens: torch.Tensor
    prompt_lora_mapping: torch.Tensor

    def io_types(self) -> str:
        return (
            f"{dtype_to_str(self.input.dtype)}x"
            f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
            f"{dtype_to_str(self.output.dtype)}"
        )

    def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType):
        return (
            size * top_k_num
            if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
            else size
        )

    @staticmethod
    def make(
        ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
    ) -> "BenchmarkTensors":
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size,
            ctx.seq_length,
            ctx.hidden_size,
            ctx.lora_rank,
            ctx.num_loras,
            ctx.num_slices,
            ctx.top_k_num,
            ctx.num_experts,
        )
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        input_tensor, lora_weights, output_tensor = make_rand_tensors(
            a_shape, b_shape, c_shape, a_type, b_type, c_type,
            num_slices=ctx.num_slices,
        )

        # Make metadata tensors.
        # Keep the metadata tensors in the CPU for further processing if needed.
        # The tensors get moved to the GPU before benchmarking.
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Make metadata tensors involved in correctness testing.
        # Prepare seq lens tensor
        seq_len_tensor = torch.randint(
            ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,)
        )
        assert total_tokens == seq_len_tensor.sum()
        # Prepare prompt lora indices tensor
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
        )

        # Make LoRAKernelMeta
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens,
            ctx.batch_size,
            prompt_lora_indices_tensor,
            seq_len_tensor,
            "cpu",
        )
        lora_kernel_meta = LoRAKernelMeta.make(
            max_loras=ctx.num_loras,
            max_num_tokens=token_lora_indices_tensor.size(0),
            device="cpu",
        )
        lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)

        return BenchmarkTensors(
            input_tensor,
            lora_weights,
            output_tensor,
            lora_kernel_meta,
            seq_len_tensor,
            prompt_lora_indices_tensor,
        )

    def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None:
        """
        Fails asserts when non-conformality is detected.
        """
        num_tokens = (
            self.input.shape[1]
            if op_type.is_fused_moe_lora_expand_fn()
            else self.input.shape[-2]
        )
        # check metadata tensors
        ## In down shrink case, each token is repeated top_k_num times
        assert num_tokens == self.get_num_tokens(
            torch.sum(self.seq_lens), ctx.top_k_num, op_type
        ), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}"
        num_seqs = self.seq_lens.shape[0]
        # assert self.seq_start_loc.shape[0] == num_seqs
        ## In down shrink case, each prompt corresponds to top_k_num sequences
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.get_num_tokens(
            self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
        )

    def to_device(self, device: str):
        """
        Transfer tensors to device if the tensors aren't already on the device
        """

        def to_device(tensor: torch.Tensor):
            if tensor.device != device:
                tensor = tensor.to(device=device)
            return tensor

        self.input = to_device(self.input)
        self.output = to_device(self.output)
        self.seq_lens = to_device(self.seq_lens)
        self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
        for i in range(len(self.lora_weights_lst)):
            self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

        # LoRA meta
        for field_name in LoRAKernelMeta.__dataclass_fields__:
            field = getattr(self.lora_kernel_meta, field_name)
            assert isinstance(field, torch.Tensor)
            setattr(
                self.lora_kernel_meta,
                field_name,
                to_device(field) if field_name != "no_lora_flag_cpu" else field,
            )

    def metadata(
        self, ctx: BenchmarkContext, op_type: OpType
    ) -> tuple[int, int, int, int]:
        """
        Return num_seqs, num_tokens, max_seq_len and num_slices
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.get_num_tokens(
            self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
        )
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def fused_moe_lora_data_prepare(
        self,
        block_size: int,
        token_lora_mapping: torch.Tensor,
        ctx: BenchmarkContext,
    ):
        def moe_lora_align_block_size(
            topk_ids: torch.Tensor,
            token_lora_mapping: torch.Tensor,
            block_size: int,
            num_experts: int,
            max_loras: int,
            expert_map: torch.Tensor | None = None,
            pad_sorted_ids: bool = False,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """
            Aligns tokens and experts into block-sized chunks for LoRA-based
            mixture-of-experts (MoE) execution.
            """
            max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
            if pad_sorted_ids:
                max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
            sorted_ids = torch.empty(
                (max_loras * max_num_tokens_padded,),
                dtype=torch.int32,
                device=topk_ids.device,
            )
            max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
            # Expert ids must be set default to -1 to prevent a blank block
            expert_ids = torch.empty(
                (max_loras * max_num_m_blocks,),
                dtype=torch.int32,
                device=topk_ids.device,
            )
            num_tokens_post_pad = torch.empty(
                (max_loras), dtype=torch.int32, device=topk_ids.device
            )

            ops.moe_lora_align_block_size(
                topk_ids,
                token_lora_mapping,
                num_experts,
                block_size,
                max_loras,
                max_num_tokens_padded,
                max_num_m_blocks,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
            if expert_map is not None:
                expert_ids = expert_map[expert_ids]

            return sorted_ids, expert_ids, num_tokens_post_pad

        num_tokens = ctx.batch_size
        curr_topk_ids = torch.randint(
            0,
            ctx.num_experts,
            (num_tokens, ctx.top_k_num),
            device="cuda",
            dtype=torch.int32,
        )
        topk_weights = torch.randint(
            0,
            ctx.num_experts,
            (num_tokens, ctx.top_k_num),
            device="cuda",
            dtype=torch.int32,
        )

        (sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = (
            moe_lora_align_block_size(
                topk_ids=curr_topk_ids,
                token_lora_mapping=token_lora_mapping,
                block_size=block_size,
                num_experts=ctx.num_experts,
                max_loras=ctx.num_loras,
            )
        )

        sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1)
        expert_ids = expert_ids_lora.view(ctx.num_loras, -1)
        num_tokens_post_padded = num_tokens_post_padded_lora
        return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded)
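
    # Padding arithmetic above (illustrative numbers): with topk_ids of shape
    # (64, 2), num_experts=8 and block_size=16, max_num_tokens_padded is
    # 64 * 2 + 8 * (16 - 1) = 248, leaving room to pad every expert's token
    # list to a block boundary, and max_num_m_blocks = cdiv(248, 16) = 16.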

    def as_lora_shrink_kwargs(
        self, ctx: BenchmarkContext, op_type: OpType
    ) -> dict[str, Any]:
        self.sanity_check(ctx, op_type)
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata(ctx, op_type)

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            "inputs": self.input,
            "lora_a_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "scaling": 1.0,
            "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
        }

    def as_lora_expand_kwargs(
        self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool
    ) -> dict[str, Any]:
        self.sanity_check(ctx, op_type)
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata(ctx, op_type)

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            "inputs": self.input,
            "lora_b_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "offset_start": 0,
            "add_inputs": add_inputs,
            "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
        }

    def as_fused_moe_lora_shrink_kwargs(
        self, ctx: BenchmarkContext, op_type: OpType
    ) -> dict[str, Any]:
        self.sanity_check(ctx, op_type)
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata(ctx, op_type)

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape : [num_tokens, hidden_size] for gate_up
        # Expected input shape : [top_k_num * num_tokens, hidden_size] for down
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size]
        assert len(lw_shape) == 4
        assert lw_shape[-1] == hidden_size
        lora_rank = lw_shape[-2]
        # Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank]
        assert len(o_shape) == 4
        assert (
            o_shape
            == (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank)
            if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
            else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank)
        )

        kernel_config = get_lora_op_configs(
            op_type.name.lower(),
            max_loras=lw_shape[0],
            batch=num_tokens,
            hidden_size=hidden_size,
            rank=lora_rank,
            num_slices=num_slices,
            add_inputs=False,
        )

        (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
            self.fused_moe_lora_data_prepare(
                block_size=kernel_config["BLOCK_SIZE_M"],
                token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
                ctx=ctx,
            )
        )

        return {
            "qcurr_hidden_states": self.input,
            "lora_a_stacked": self.lora_weights_lst,
            "a_intermediate_cache1": self.output,
            "topk_weights": topk_weights,
            "sorted_token_ids": sorted_token_ids,
            "expert_ids": expert_ids,
            "num_tokens_post_padded": num_tokens_post_padded,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "top_k_num": ctx.top_k_num,
            "device": self.input.device,
            "N": lora_rank,
            "M": topk_weights.shape[0],
            "EM": sorted_token_ids.shape[1],
            "K": self.input.shape[1],
            "num_tokens": num_tokens,
            "num_experts": ctx.num_experts,
            "num_slices": num_slices,
            "shrink_block_size_m": kernel_config["BLOCK_SIZE_M"],
            "shrink_block_size_n": kernel_config["BLOCK_SIZE_N"],
            "shrink_block_size_k": kernel_config["BLOCK_SIZE_K"],
            "shrink_group_size_m": kernel_config["GROUP_SIZE_M"],
            "shrink_num_warps": kernel_config["NUM_WARPS"],
            "shrink_num_stages": kernel_config["NUM_STAGES"],
            "shrink_split_k": kernel_config.get("SPLIT_K", 1),
            "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
        }

    def as_fused_moe_lora_expand_kwargs(
        self, ctx: BenchmarkContext, op_type: OpType
    ) -> dict[str, Any]:
        self.sanity_check(ctx, op_type)
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata(ctx, op_type)

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
        assert len(i_shape) == 4
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[-1]
        # Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
        assert len(lw_shape) == 4
        assert lw_shape[-1] == lora_rank
        hidden_size = lw_shape[-2]
        # Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
        assert len(o_shape) == 3
        assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices)

        kernel_config = get_lora_op_configs(
            op_type.name.lower(),
            max_loras=lw_shape[0],
            batch=num_tokens,
            hidden_size=hidden_size,
            rank=lora_rank,
            num_slices=num_slices,
            add_inputs=False,
        )

        (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
            self.fused_moe_lora_data_prepare(
                block_size=kernel_config["BLOCK_SIZE_M"],
                token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
                ctx=ctx,
            )
        )

        return {
            "a_intermediate_cache1": self.input,
            "lora_b_stacked": self.lora_weights_lst,
            "output": self.output,
            "topk_weights": topk_weights,
            "sorted_token_ids": sorted_token_ids,
            "expert_ids": expert_ids,
            "num_tokens_post_padded": num_tokens_post_padded,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "top_k_num": ctx.top_k_num,
            "device": self.input.device,
            "N": lora_rank,
            "M": topk_weights.shape[0],
            "EM": sorted_token_ids.shape[1],
            "K": self.input.shape[1],
            "num_tokens": num_tokens,
            "num_experts": ctx.num_experts,
            "num_slices": num_slices,
            "max_lora_rank": lora_rank,
            "w1_output_dim_size": lw_shape[2],
            "expand_block_size_m": kernel_config["BLOCK_SIZE_M"],
            "expand_block_size_n": kernel_config["BLOCK_SIZE_N"],
            "expand_block_size_k": kernel_config["BLOCK_SIZE_K"],
            "expand_group_size_m": kernel_config["GROUP_SIZE_M"],
            "expand_num_warps": kernel_config["NUM_WARPS"],
            "expand_num_stages": kernel_config["NUM_STAGES"],
            "expand_split_k": kernel_config.get("SPLIT_K", 1),
            "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
        }

    def bench_fn_kwargs(
        self,
        ctx: BenchmarkContext,
        op_type: OpType,
        add_inputs: bool | None = None,
    ) -> dict[str, Any]:
        if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.LORA_SHRINK:
            return self.as_lora_shrink_kwargs(ctx, op_type)
        if op_type == OpType.LORA_EXPAND:
            return self.as_lora_expand_kwargs(ctx, op_type, add_inputs)
        if op_type.is_fused_moe_lora_shrink_fn():
            return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type)
        if op_type.is_fused_moe_lora_expand_fn():
            return self.as_fused_moe_lora_expand_kwargs(ctx, op_type)
        raise ValueError(f"Unrecognized optype {self}")

    def test_correctness(
        self,
        ctx: BenchmarkContext,
        op_type: OpType,
        expand_fn_add_inputs: bool | None,
    ) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.
        """
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")
        ref_output = self.output.clone()

        self.output.zero_()
        op_type.bench_fn()(**self.bench_fn_kwargs(ctx, op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs,
        )

        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)


def bench_optype(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: int | None = None,
    expand_fn_add_inputs: bool | None = None,
    test_correctness: bool = False,
) -> TMeasurement:
    assert arg_pool_size >= 1
    if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    # BenchmarkContext -> BenchmarkTensors
    bench_tensors: list[BenchmarkTensors] = [
        BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
    ]
    for bt in bench_tensors:
        bt.sanity_check(ctx, op_type)

    # Test correctness of our implementation.
    if test_correctness:
        assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], (
            f"Correctness testing is not supported for {op_type.name}."
        )
        assert all(
            [
                bt.test_correctness(ctx, op_type, expand_fn_add_inputs)
                for bt in bench_tensors
            ]
        )

    # BenchmarkTensors -> dict (kwargs)
    kwargs_list = [
        bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs)
        for bt in bench_tensors
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    _LORA_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
    for kwargs in kwargs_list:
        op_type.bench_fn()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    describe_args = (
        f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
    )
    description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"

    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    timer = None
    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        op_type.bench_fn(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer
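
# ArgPool merge sketch (illustrative): if kwargs_list is
# [{"inputs": t0, ...}, {"inputs": t1, ...}], the merge in bench_optype
# produces {"inputs": ArgPool([t0, t1]), ...}, so successive invocations
# rotate through distinct tensor sets and hardware caching effects are
# mitigated.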

def bench_torch_mm(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: int | None = None,
) -> TMeasurement:
    """
    Benchmark basic torch.mm as a roofline.

    When all the input tokens have the same LoRA ID, the LoRA kernels are just
    a matmul. This torch.mm benchmark serves as a roofline for that case.

    input op_type is used in determining the m, k, n dimensions for the matmul.
    """
    batch_size, hidden_size, lora_rank, seq_length, dtype = (
        ctx.batch_size,
        ctx.hidden_size,
        ctx.lora_rank,
        ctx.seq_length,
        ctx.dtype,
    )

    m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank)
    # For a fairer comparison.
    n = n * ctx.num_slices

    # Get matmul input and output tensors for A x B = C
    As, Bs, Cs = [], [], []
    for _ in range(arg_pool_size):
        As.append(torch.rand((m, k), dtype=dtype).to("cuda"))
        Bs.append(torch.rand((n, k), dtype=dtype).to("cuda").t())
        Cs.append(torch.rand((m, n), dtype=dtype).to("cuda"))

    # Make torch.mm kwargs
    mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)}

    description = (
        f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
        f"x{dtype_to_str(dtype)}"
        f"=>{dtype_to_str(dtype)})"
    )
    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        torch.mm,
        **mm_kwargs,
    ) as bench:
        return bench.run()
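
# Roofline sizing (illustrative): with num_slices=2, hidden_size=4096 and
# lora_rank=16, the shrink roofline above runs an (m, 4096) x (4096, 32)
# matmul (n = 16 * 2), matching the total work of two rank-16 shrink slices.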

# runner
def use_cuda_graph_recommendation() -> str:
    return """
            Triton kernels have a significant launch overhead when
            launched directly via Python. This overhead is more noticeable
            for small problem sizes. For these cases, it is recommended
            to use the script with `--cuda-graph-nops N` to benchmark N
            consecutive invocations of the benchmarking operations from
            inside a CUDA Graph. Note that the returned measurement is for N
            invocations of the operation.
        """


def print_timers(timers: list[TMeasurement], args: argparse.Namespace | None = None):
    compare = TBenchmark.Compare(timers)
    compare.print()

    if args and args.cuda_graph_nops:
        print(
            f"Note : The timings reported above are for {args.cuda_graph_nops} "
            "consecutive invocations of the benchmarking functions. "
            f"Please divide by {args.cuda_graph_nops} for single invocation "
            "timings."
        )

    print(
        "Note on Comparison with torch.mm : The torch.mm numbers are "
        "benchmark numbers of a simple matmul emulating the single lora "
        "case. It is provided as a roofline for comparing our LoRA Kernel "
        "implementations. It is expected that the LoRA kernels will be "
        "slower than torch.mm in cases where num_loras is big. But for "
        "small num_loras the goal should be to match the torch.mm numbers."
    )


def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: list[OpType] = args.op_types

            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices
                    )

                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(
                            _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops
                        )
                    )

                    # Benchmark bench_op
                    expand_fn_add_inputs = (
                        [None]
                        if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn()
                        else args.expand_fn_add_inputs
                    )
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(
                                _ctx,
                                args.arg_pool_size,
                                bench_op,
                                args.cuda_graph_nops,
                                add_input_arg,
                                args.test_correctness,
                            )
                        )

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        if not od.exists():
            od.mkdir()

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)


def as_benchmark_contexts(
    hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]:
    ctxs: list[BenchmarkContext] = []
    for (
        batch_size,
        hidden_size,
        lora_rank,
        num_loras,
        sort_by_lora_id,
        top_k_num,
        num_experts,
    ) in product(  # noqa
        args.batch_sizes,
        list(hidden_sizes),
        lora_ranks,
        args.num_loras,
        args.sort_by_lora_id,
        args.top_k_nums,
        args.num_experts,
    ):
        ctxs.append(
            BenchmarkContext(
                batch_size=batch_size,
                hidden_size=hidden_size,
                lora_rank=lora_rank,
                num_loras=num_loras,
                num_active_loras=args.num_active_loras
                if args.num_active_loras
                else num_loras,
                # To be filled based on the OpType to benchmark
                seq_length=None,
                sort_by_lora_id=sort_by_lora_id,
                dtype=args.dtype,
                top_k_num=top_k_num,
                num_experts=num_experts,
                # To be filled based on the OpType to benchmark
                num_slices=None,
            )
        )

    return ctxs


def run_list_bench(args: argparse.Namespace):
    print(args)

    print(
        "List bench :\n"
        f" Hidden Sizes {args.hidden_sizes}"
        f" LoRA Ranks {args.lora_ranks}"
    )

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)


def run_range_bench(args: argparse.Namespace):
    print(args)

    hidden_sizes = list(
        range(
            args.hidden_sizes_start,
            args.hidden_sizes_end + 1,
            args.hidden_sizes_increment,
        )
    )
    lora_ranks = list(
        range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment)
    )

    print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args
    )

    run(args, bench_contexts)


def run_model_bench(args: argparse.Namespace):
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            hidden_sizes.add(KN[1])
        return hidden_sizes

    # Get all hidden sizes
    hidden_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size))

    print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)


if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        return s.lower() in ["true", "1"]

    def add_common_command_args(p: argparse.ArgumentParser):
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']",
        )

        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            help="Run profiles with a pool of input/output/meta tensors instead "
            "of simply reusing the same tensors for all runs. A bigger arg-pool "
            "mitigates hardware caching effects during benchmarking.",
        )

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=(
                "when set profiling is done using cudagraph, "
                "with the given number of operations in a graph. "
                "Note that the measurement returned is the time "
                "taken for N consecutive executions of the benchmarking "
                "functions, where N is the value of this argument."
            ),
        )
        p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS)
        p.add_argument(
            "--num-active-loras",
            type=int,
            default=None,
            help="Active LoRAs. When None, all LoRAs are active",
        )
        p.add_argument(
            "--sort-by-lora-id",
            nargs="+",
            type=get_bool,
            default=DEFAULT_SORT_BY_LORA_IDS,
        )
        p.add_argument(
            "--op-types", nargs="+", type=OpType.from_str, default=list(OpType)
        )
        p.add_argument(
            "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS
        )
        p.add_argument(
            "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
        )
        p.add_argument(
            "--expand-fn-add-inputs",
            nargs="+",
            type=get_bool,
            default=DEFAULT_EXPAND_FN_ADD_INPUTS,
        )
        p.add_argument(
            "-o",
            "--output-directory",
            type=str,
            help=(
                "Output directory to store the list of benchmarking "
                "TMeasurement objects as a pickle file"
            ),
        )

        p.add_argument(
            "--test-correctness",
            action="store_true",
            help=(
                "When enabled, the benchmarking functions are tested "
                "for correctness before the actual benchmarking"
            ),
        )

        p.add_argument(
            "--top-k-nums",
            nargs="+",
            type=int,
            default=DEFAULT_TOP_K_NUMS,
            help="Top-K values for MoE LoRA operations",
        )

        p.add_argument(
            "--num-experts",
            nargs="+",
            type=int,
            default=DEFAULT_NUM_EXPERTS,
            help="Number of experts for MoE LoRA operations",
        )

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
    {use_cuda_graph_recommendation()}

    list_bench example:
        python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    model_bench example:
        python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    range_bench example:
        python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
        """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument(
        "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES
    )
    list_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment", type=int, required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment", type=int, required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
benchmarks/kernels/benchmark_machete.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
import math
import os
import pickle as pkl
import time
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product

import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    marlin_permute_scales,
    marlin_zero_points,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows,
    quantize_weights,
)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]

NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)

if NVTX_PROFILE:
    import nvtx


def terse_type_name(dt):
    return {
        torch.bfloat16: "bf16",
        torch.float16: "fp16",
        torch.int8: "int8",
        torch.float8_e4m3fn: "fp8",
        torch.float: "float",
        torch.int: "int",
    }[dt]


@dataclass
class BenchmarkTensors:
    w_ref: torch.Tensor
    a: torch.Tensor

    w_q: torch.Tensor
    group_size: int | None
    wtype: ScalarType
    w_g_s: torch.Tensor
    w_g_zp: torch.Tensor | None
    w_ch_s: torch.Tensor | None
    w_tok_s: torch.Tensor | None


@dataclass
class TypeConfig:
    act_type: torch.dtype
    weight_type: ScalarType
    output_type: torch.dtype | None
    group_scale_type: torch.dtype | None
    group_zero_type: torch.dtype | None
    channel_scale_type: torch.dtype | None
    token_scale_type: torch.dtype | None


def rand_data(shape, dtype=torch.float16, scale=1):
    if dtype.is_floating_point:
        return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
    else:
        return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")


def quantize_and_pack(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(
        w,
        wtype,
        group_size=group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True,
    )

    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    return w_ref, w_q, w_s, w_zp


def create_bench_tensors(
    shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> list[BenchmarkTensors]:
    m, n, k = shape

    # we want to make sure that weights don't fit into L2 cache between runs so
    # we construct enough weights to exceed L2 cache, which is 50mb on a H100
    # so we target total weight size > 2*50mb
    num_weights = math.ceil(
        2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
    )
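
    # Illustrative numbers: for k = n = 4096 and 4-bit weights, each weight
    # tensor is 4096 * 4096 * 4 bits = 8 MiB, so num_weights =
    # ceil(100 MiB / 8 MiB) = 13 tensors are rotated through to defeat L2
    # caching.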

    a = rand_data((m, k), types.act_type, scale=5)

    benchmark_tensors: list[BenchmarkTensors] = []
    for _ in range(num_weights):
        w = rand_data((k, n), types.act_type, scale=5)

        if types.group_scale_type is not None:
            w = w.to(types.group_scale_type)
        if w.dtype.itemsize == 1:
            w = w.to(torch.float16)

        w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
            a.dtype,
            w,
            types.weight_type,
            types.group_scale_type,
            group_size,
            types.group_zero_type is not None,
        )

        if not a.dtype.is_floating_point:
            aiinfo = torch.iinfo(a.dtype)
            w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)

        w_ref = w_ref.to(torch.float32)

        w_ch_s = (
            None
            if types.channel_scale_type is None
            else rand_data((n,), types.channel_scale_type)
        )
        w_tok_s = (
            None
            if types.token_scale_type is None
            else rand_data((m,), types.token_scale_type)
        )

        benchmark_tensors.append(
            BenchmarkTensors(
                w_ref=w_ref,
                a=a,
                w_q=w_q_packed,
                wtype=types.weight_type,
                w_g_s=w_s,
                w_g_zp=w_zp,
                group_size=group_size,
                w_ch_s=w_ch_s,
                w_tok_s=w_tok_s,
            )
        )

    return benchmark_tensors


def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    a = bt.a
    w = bt.w_ref.to(bt.a.dtype)  # use float reference tensor
    if a.dtype not in [torch.float16, torch.bfloat16]:
        a = a.to(torch.float16)
        w = w.to(torch.float16)
    return lambda: torch.matmul(a, w)


def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    if bt.w_ch_s is not None and bt.w_tok_s is not None:
        scale_a = bt.w_tok_s.to(torch.float32)
        scale_b = bt.w_ch_s.to(torch.float32)
    else:
        scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
        scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
    w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
    return lambda: ops.cutlass_scaled_mm(
        bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
    )


def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    device = bt.a.device

    workspace = MarlinWorkspace(
        bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    if bt.w_g_zp is None:
        w_zp = torch.empty(0, dtype=torch.int, device=device)
    else:
        w_zp = marlin_zero_points(
            bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
        )

    if bt.group_size is None:
        w_s = torch.tensor([], device="cuda", dtype=torch.half)
    else:
        w_s = marlin_permute_scales(
            bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
        )

    sort_indices = torch.empty(0, dtype=torch.int, device=device)
    g_idx = torch.empty(0, dtype=torch.int, device=device)
    w_q = ops.gptq_marlin_repack(
        bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
    )

    if bt.a.dtype.is_floating_point:
        assert bt.w_ch_s is None
        assert bt.w_tok_s is None
        assert bt.group_size is not None

        fn = lambda: ops.marlin_gemm(
            a=bt.a,
            c=None,
            b_q_weight=w_q,
            b_bias=None,
            b_scales=w_s,
            a_scales=None,
            global_scale=None,
            b_zeros=w_zp,
            g_idx=g_idx,
            perm=sort_indices,
            workspace=workspace.scratch,
            b_q_type=bt.wtype,
            size_m=bt.a.shape[0],
            size_n=bt.w_ref.shape[1],
            size_k=bt.w_ref.shape[0],
            is_k_full=True,
            is_zp_float=False,
        )
    else:
        assert bt.a.dtype == torch.int8
        assert bt.wtype == scalar_types.uint4b8
        raise NotImplementedError("QQQ is not supported anymore")

    return fn
def
machete_create_bench_fn
(
bt
:
BenchmarkTensors
,
out_type
=
torch
.
dtype
,
schedule
=
None
)
->
Callable
:
w_q
=
bt
.
w_q
.
t
().
contiguous
().
t
()
# make col major
w_q
=
ops
.
machete_prepack_B
(
w_q
,
bt
.
a
.
dtype
,
bt
.
wtype
,
None
if
bt
.
w_g_s
is
None
else
bt
.
w_g_s
.
dtype
)
w_g_zp
=
bt
.
w_g_zp
if
w_g_zp
is
not
None
:
w_g_zp
=
-
1
*
bt
.
w_g_s
*
(
w_g_zp
.
to
(
bt
.
w_g_s
.
dtype
))
return
lambda
:
ops
.
machete_mm
(
a
=
bt
.
a
,
b_q
=
w_q
,
b_type
=
bt
.
wtype
,
b_group_scales
=
bt
.
w_g_s
,
b_group_zeros
=
w_g_zp
,
b_group_size
=
bt
.
group_size
,
b_channel_scales
=
bt
.
w_ch_s
,
a_token_scales
=
bt
.
w_tok_s
,
out_type
=
out_type
,
schedule
=
schedule
,
)
def
cutlass_w4a8_create_bench_fn
(
bt
:
BenchmarkTensors
,
out_type
=
torch
.
dtype
,
schedule
=
None
)
->
Callable
:
w_q
=
bt
.
w_q
.
t
().
contiguous
().
t
()
# make col major
w_q
=
ops
.
cutlass_encode_and_reorder_int4b
(
w_q
)
# expects fp8 scales
w_s
=
ops
.
cutlass_pack_scale_fp8
(
bt
.
w_g_s
.
to
(
torch
.
float8_e4m3fn
))
return
lambda
:
ops
.
cutlass_w4a8_mm
(
a
=
bt
.
a
,
b_q
=
w_q
,
b_group_scales
=
w_s
,
b_group_size
=
bt
.
group_size
,
b_channel_scales
=
bt
.
w_ch_s
,
a_token_scales
=
bt
.
w_tok_s
,
maybe_schedule
=
schedule
,
)
# impl
# bench
def
bench_fns
(
label
:
str
,
sub_label
:
str
,
description
:
str
,
fns
:
list
[
Callable
]):
min_run_time
=
1
if
not
NVTX_PROFILE
else
0.1
res
=
TBenchmark
.
Timer
(
stmt
=
"""
for fn in fns:
fn()
"""
,
globals
=
{
"fns"
:
fns
},
label
=
label
,
sub_label
=
sub_label
,
description
=
description
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
if
NVTX_PROFILE
:
with
(
nvtx
.
annotate
(
"mm-bench"
),
nvtx
.
annotate
(
f
"
{
label
}
|
{
sub_label
}
|
{
description
}
"
),
):
fns
[
0
]()
return
res
_SWEEP_SCHEDULES_RESULTS
:
pd
.
DataFrame
|
None
=
None
_SWEEP_SCHEDULES_RESULTS_CSV
:
str
|
None
=
None
def bench(
    types: TypeConfig,
    group_size: int,
    m: int,
    k: int,
    n: int,
    label: str,
    sub_label: str,
    sweep_schedules: bool = True,
) -> list[TMeasurement]:
    benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
    sub_label += f", L={len(benchmark_tensors)}"

    name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
    if types.group_scale_type is not None:
        name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
    if types.group_zero_type is not None:
        name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
    if group_size is not None:
        name_type_string += f"-G{group_size}"
    if types.channel_scale_type is not None:
        name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
    if types.token_scale_type is not None:
        name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"

    timers = []
    # pytorch impl
    timers.append(
        bench_fns(
            label,
            sub_label,
            "torch.matmul (fp16)",
            [torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
        )
    )

    if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
                [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
            )
        )

    if types.act_type != torch.float8_e4m3fn:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"marlin ({name_type_string})",
                [marlin_create_bench_fn(bt) for bt in benchmark_tensors],
            )
        )

    # machete
    timers.append(
        bench_fns(
            label,
            sub_label,
            f"machete ({name_type_string})",
            [
                machete_create_bench_fn(bt, out_type=types.output_type)
                for bt in benchmark_tensors
            ],
        )
    )

    # cutlass w4a8
    if types.act_type == torch.float8_e4m3fn and group_size == 128:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"cutlass w4a8 ({name_type_string})",
                [
                    cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
                    for bt in benchmark_tensors
                ],
            )
        )

    if sweep_schedules:
        global _SWEEP_SCHEDULES_RESULTS

        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(
            a_type=types.act_type,
            b_type=types.weight_type,
            group_scales_type=types.group_scale_type,
            group_zeros_type=types.group_zero_type,
            token_scales_type=types.token_scale_type,
            channel_scales_type=types.channel_scale_type,
            out_type=types.output_type,
        )

        if schedules is None or len(schedules) == 0:
            raise ValueError("No schedules found to sweep")

        for schedule in reversed(schedules):
            schedule_M = int(schedule.split("_")[0].split("x")[1])

            # Prune known bad schedules
            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
                continue

            res = bench_fns(
                label,
                sub_label,
                "machete_best",
                [
                    machete_create_bench_fn(
                        bt, out_type=types.output_type, schedule=schedule
                    )
                    for bt in benchmark_tensors
                ],
            )

            results_row = {
                "M": m,
                "K": k,
                "N": n,
                "group_size": group_size,
                "schedule": schedule,
                "median": res.median,
            }
            if _SWEEP_SCHEDULES_RESULTS is None:
                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
            _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

            print(f"{res.median:5.5}", schedule)
            if not best or res.median < best.median:
                best = res
                best_schedule = schedule
        print("Best schedule:", best_schedule)
        timers.append(best)

    return timers
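
# Note on the schedule sweep above: machete schedule names are assumed to begin
# with a tile-shape token such as "128x16_...", so
# schedule.split("_")[0].split("x")[1] extracts the M-tile (16 in this
# example); schedules whose M-tile is far larger (>= 2 * max(m, 16)) or far
# smaller (< m // 4) than the benchmarked m are pruned without timing.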
# runner
def print_timers(timers: list[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
    types = TypeConfig(
        act_type=args.act_type,
        weight_type=scalar_types.uint4b8
        if args.group_zero_type is None
        else scalar_types.uint4,
        output_type=args.out_type,
        group_scale_type=args.group_scale_type,
        group_zero_type=args.group_zero_type,
        channel_scale_type=args.channel_scale_type,
        token_scale_type=args.token_scale_type,
    )

    results: list[TMeasurement] = []
    for m, k, n in MKNs:
        timers = bench(
            types,
            args.group_size,
            m,
            k,
            n,
            f"{args.act_type}-gemm",
            f"MKN=({m}x{k}x{n})",
            sweep_schedules=args.sweep_schedules,
        )
        print_timers(timers)
        results.extend(timers)

    return results
# output makers
def make_output(
    data: list[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)
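
# The emitted .pkl can be re-inspected offline, e.g. (a minimal sketch; the
# filename below is hypothetical):
#
#   import pickle as pkl
#   with open("square_bench-torch.float16-1700000000.pkl", "rb") as f:
#       measurements = pkl.load(f)
#   print(measurements[0].median)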
# argparse runners
def run_square_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args, MKNs)

    make_output(data, MKNs, f"square_bench-{args.act_type}")
def run_range_bench(args):
    m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
    m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
    m_increment, k_increment, n_increment = (
        int(x) for x in args.dim_increment.split(",")
    )
    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    MKNs = list(product(Ms, Ks, Ns))

    data = run(args, MKNs)

    make_output(data, MKNs, f"range_bench-{args.act_type}")
def run_model_bench(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args, MKNs)
        model_bench_data.append(data)

    type_string = f"{args.act_type}"

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {type_string} {model}-TP{tp_size} ====")
        print_timers(data)

    timestr = time.strftime("%Y%m%d-%H%M%S")

    all_results = []
    for d in model_bench_data:
        all_results.extend(d)

    # pickle all data
    with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
        args_dict = vars(args)
        args_dict.pop("func")
        pkl.dump(
            {
                "args": args_dict,
                "results": all_results,
            },
            f,
        )
if __name__ == "__main__":

    def to_torch_dtype(dt):
        return {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "int8": torch.int8,
            "float8_e4m3fn": torch.float8_e4m3fn,
            "int": torch.int,
            "float": torch.float,
        }[dt]

    class ToTorchDtype(argparse.Action):
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, to_torch_dtype(values))

    parser = FlexibleArgumentParser(
        description="""Benchmark Machete GEMM.

    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
        """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--act-type",
        action=ToTorchDtype,
        required=True,
        choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
    )
    parser.add_argument(
        "--group-scale-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--group-zero-type",
        type=to_torch_dtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--channel-scale-type",
        action=ToTorchDtype,
        choices=["float"],
    )
    parser.add_argument(
        "--token-scale-type",
        action=ToTorchDtype,
        choices=["float"],
    )
    parser.add_argument(
        "--out-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--group-size",
        type=int,
        help="Available options are ['None', '-1', '128'], default=128",
        default=128,
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    parser.add_argument(
        "--sweep-csv-out",
        help="CSV to store sweep results",
        default="sch_sweep_results.csv",
    )
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument(
        "--dim-start",
        type=str,
        required=True,
        help="Start value for M,K,N as comma separated list",
    )
    range_parser.add_argument(
        "--dim-end",
        type=str,
        required=True,
        help="End value (inclusive) for M,K,N as comma separated list",
    )
    range_parser.add_argument(
        "--dim-increment",
        type=str,
        required=True,
        help="Increment value for M,K,N as comma separated list",
    )
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    model_parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()

    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
    args.func(args)

    if _SWEEP_SCHEDULES_RESULTS is not None:
        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
benchmarks/kernels/benchmark_marlin.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES,
    query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    FP4_MARLIN_SUPPORTED_GROUP_SIZES,
    rand_marlin_weight_fp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace,
    awq_marlin_quantize,
    marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    act_order: bool,
    is_k_full: bool,
    quant_type: ScalarType,
    group_size: int,
    size_m: int,
    size_k: int,
    size_n: int,
):
    label = "Quant Matmul"
    sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
        model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
    )
    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    if act_order and (group_size == -1 or group_size == size_k or has_zp):
        return
    if size_k % group_size != 0:
        return

    repack_supported = group_size in MARLIN_SUPPORTED_GROUP_SIZES
    allspark_supported = (
        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
        and group_size == -1
        and not act_order
        and is_k_full
    )

    def gen_marlin_params():
        # Marlin quant
        marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
        if quant_type == scalar_types.float4_e2m1f:
            if group_size != 16 or act_order:
                return
            marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
                b.T, group_size
            )
        elif quant_type == scalar_types.float8_e4m3fn:
            if group_size not in [-1, 128] or act_order:
                return
            marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
        elif group_size == 16:
            return
        elif has_zp:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
                b, quant_type, group_size
            )
        else:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
                marlin_quantize(b, quant_type, group_size, act_order)
            )
        return (
            marlin_w_ref,
            marlin_q_w,
            marlin_s,
            marlin_s2,
            marlin_zp,
            marlin_g_idx,
            marlin_sort_indices,
        )

    def gen_repack_params():
        q_w_gptq = None
        repack_sort_indices = None
        if repack_supported:
            (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
                b, quant_type, group_size, act_order
            )
            q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

            # For act_order, sort the "weights" and "g_idx"
            # so that group ids are increasing
            repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
            if act_order:
                (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
        return q_w_gptq, repack_sort_indices

    def gen_allspark_params():
        qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
            CUBLAS_M_THRESHOLD
        ) = None
        nonlocal allspark_supported
        if allspark_supported:
            properties = torch.cuda.get_device_properties(b.device.index)
            sm_count = properties.multi_processor_count
            sm_version = properties.major * 10 + properties.minor

            supported_arch = sm_version >= 80 and sm_version < 90
            allspark_supported = allspark_supported and supported_arch
            if supported_arch:
                w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
                qw = qw.to(torch.uint8)

                qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
                    qw, s, zp, has_zp
                )
                CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
        return (
            qw_reorder,
            s_reorder,
            zp_reorder,
            sm_count,
            sm_version,
            CUBLAS_M_THRESHOLD,
        )

    marlin_params = gen_marlin_params()
    if marlin_params is None:
        # Unsupported (quant_type, group_size, act_order) combination.
        return
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        marlin_g_idx,
        marlin_sort_indices,
    ) = marlin_params

    q_w_gptq, repack_sort_indices = gen_repack_params()
    qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
        gen_allspark_params()
    )

    # Prepare
    marlin_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    globals = {
        # Gen params
        "quant_type": quant_type,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_s2": marlin_s2,
        "marlin_zp": marlin_zp,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # AllSpark W8A16 params
        "qw_reorder": qw_reorder,
        "s_reorder": s_reorder,
        "zp_reorder": zp_reorder,
        "sm_count": sm_count,
        "sm_version": sm_version,
        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
        # Kernels
        "marlin_gemm": ops.marlin_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
    }

    min_run_time = 1

    # Warmup pytorch
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    results.append(
        benchmark.Timer(
            stmt="torch.matmul(a, marlin_w_ref)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="pytorch_gemm",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    results.append(
        benchmark.Timer(
            stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="marlin_gemm",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    results.append(
        benchmark.Timer(
            stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="marlin_gemm_fp32",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    if repack_supported:
        results.append(
            benchmark.Timer(
                stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="gptq_marlin_repack",
            ).blocked_autorange(min_run_time=min_run_time)
        )

    if allspark_supported:
        results.append(
            benchmark.Timer(
                stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="allspark_w8a16_gemm_fp32",
            ).blocked_autorange(min_run_time=min_run_time)
        )
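
# Note: sm_version in gen_allspark_params is major * 10 + minor, so e.g. an
# A100 (compute capability 8.0) yields 80 and an H100 (9.0) yields 90; the
# AllSpark path above is gated to Ampere-class GPUs via 80 <= sm_version < 90.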
def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: list[benchmark.Measurement] = []
    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k = layer[0]
            size_n = layer[1]

            if len(args.limit_k) > 0 and size_k not in args.limit_k:
                continue
            if len(args.limit_n) > 0 and size_n not in args.limit_n:
                continue

            for act_order in ACT_ORDER_OPTS:
                if (
                    len(args.limit_act_order) > 0
                    and act_order not in args.limit_act_order
                ):
                    continue

                for is_k_full in K_FULL_OPTS:
                    if (
                        len(args.limit_k_full) > 0
                        and is_k_full not in args.limit_k_full
                    ):
                        continue

                    for quant_type in query_marlin_supported_quant_types():
                        if (
                            len(args.limit_num_bits) > 0
                            and quant_type.size_bits not in args.limit_num_bits
                        ):
                            continue

                        for group_size in (
                            MARLIN_SUPPORTED_GROUP_SIZES
                            + FP4_MARLIN_SUPPORTED_GROUP_SIZES
                        ):
                            if (
                                len(args.limit_group_size) > 0
                                and group_size not in args.limit_group_size
                            ):
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and (group_size == size_k or group_size == -1):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(
                                    results,
                                    model,
                                    act_order,
                                    is_k_full,
                                    quant_type,
                                    group_size,
                                    size_m,
                                    size_k,
                                    size_n,
                                )

    compare = benchmark.Compare(results)
    compare.print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1  # noqa E501
#
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument("--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
    parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
    parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
benchmarks/kernels/benchmark_mla_k_concat.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
in MLA (Multi-head Latent Attention) prefill.

This validates that the optimization from commit 8d4142bd is beneficial across
various batch sizes, not just the originally tested batch size of 32768.
"""

import time
from collections.abc import Callable

import torch

# DeepSeek-V3 MLA dimensions
NUM_HEADS = 128
QK_NOPE_HEAD_DIM = 128
PE_DIM = 64


def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    """Original torch.cat approach with expand."""
    return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)


def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    """Optimized direct copy approach (avoids expand + cat overhead)."""
    k = torch.empty(
        (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
        dtype=k_nope.dtype,
        device=k_nope.device,
    )
    k[..., : k_nope.shape[-1]] = k_nope
    k[..., k_nope.shape[-1] :] = k_pe
    return k
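
# Quick equivalence sanity check for the two methods (a minimal sketch;
# assumes a CUDA device and uses the DeepSeek-V3 dimensions above):
#
#   k_nope = torch.randn(4, NUM_HEADS, QK_NOPE_HEAD_DIM, device="cuda")
#   k_pe = torch.randn(4, 1, PE_DIM, device="cuda")
#   assert torch.equal(cat_method(k_nope, k_pe),
#                      direct_copy_method(k_nope, k_pe))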
def benchmark_method(
    method: Callable,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
    num_warmup: int = 10,
    num_iters: int = 100,
) -> float:
    """Benchmark a concatenation method and return mean latency in ms."""
    # Warmup
    for _ in range(num_warmup):
        _ = method(k_nope, k_pe)
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(num_iters):
        _ = method(k_nope, k_pe)
    torch.cuda.synchronize()
    end = time.perf_counter()

    return (end - start) / num_iters * 1000  # Convert to ms
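
# Note: torch.cuda.synchronize() is needed on both sides of the timed loop
# because CUDA kernel launches are asynchronous; without the final sync,
# perf_counter would mostly capture Python-side launch overhead instead of
# actual GPU execution time.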
@torch.inference_mode()
def run_benchmark(dtype: torch.dtype, dtype_name: str):
    """Run benchmark for a specific dtype."""
    torch.set_default_device("cuda")

    # Batch sizes to test (powers of 2 from 32 to 65536)
    batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]

    print("=" * 80)
    print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
    print("=" * 80)
    print(
        f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
        f"k_pe=[B, 1, {PE_DIM}]"
    )
    print(f"dtype: {dtype_name}")
    print()
    print(
        f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
        f"{'Speedup':>8} | {'Reduction':>10}"
    )
    print("-" * 70)

    results = []
    for batch_size in batch_sizes:
        # Create input tensors (generate in float32 then convert for FP8 compatibility)
        k_nope = torch.randn(
            batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
        ).to(dtype)
        k_pe = torch.randn(
            batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
        ).to(dtype)

        # Benchmark both methods
        cat_time = benchmark_method(cat_method, k_nope, k_pe)
        direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)

        speedup = cat_time / direct_time
        reduction = (1 - direct_time / cat_time) * 100
        results.append((batch_size, cat_time, direct_time, speedup, reduction))

        print(
            f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
            f"{speedup:>7.2f}x | {reduction:>9.1f}%"
        )

    print("=" * 80)

    # Summary statistics
    speedups = [r[3] for r in results]
    print("\nSpeedup summary:")
    print(f"  Min:  {min(speedups):.2f}x")
    print(f"  Max:  {max(speedups):.2f}x")
    print(f"  Mean: {sum(speedups) / len(speedups):.2f}x")

    # Find crossover point
    crossover_batch = None
    for batch_size, _, _, speedup, _ in results:
        if speedup >= 1.0:
            crossover_batch = batch_size
            break

    print("\nConclusion:")
    if crossover_batch:
        print(f"  - Direct copy becomes beneficial at batch size >= {crossover_batch}")

    # Filter for large batches (>= 512 which is typical for prefill)
    large_batch_speedups = [r[3] for r in results if r[0] >= 512]
    if large_batch_speedups:
        avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
        print(f"  - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
        print("  - MLA prefill typically uses large batches, so optimization is effective")

    return results
@torch.inference_mode()
def main():
    # Test bfloat16
    print("\n")
    run_benchmark(torch.bfloat16, "bfloat16")

    # Test float8_e4m3fn
    print("\n")
    run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")


if __name__ == "__main__":
    main()
benchmarks/kernels/benchmark_moe.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import gc
import json
import os
import time
from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, TypedDict

import ray
import torch
from ray.experimental.tqdm_ray import tqdm

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
    _get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import set_random_seed

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()

# Default interval for clearing Triton JIT cache during tuning
# Set to 0 to disable automatic cache clearing
_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
TRITON_CACHE_CLEAR_INTERVAL = int(os.environ.get(_CACHE_CLEAR_INTERVAL_ENV, "50"))
def clear_triton_cache():
    """Clear Triton JIT compilation cache and Python/CUDA memory.

    This helps prevent OOM during tuning with large models (many experts).
    """
    # Force Python garbage collection
    gc.collect()

    # Clear CUDA memory cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Try to clear Triton's runtime cache
    try:
        if (
            hasattr(triton, "runtime")
            and hasattr(triton.runtime, "cache")
            and hasattr(triton.runtime.cache, "clear")
        ):
            triton.runtime.cache.clear()
    except ImportError:
        # Triton not installed, skip cache clearing
        pass
    except AttributeError:
        # Triton version doesn't have expected cache API
        pass
    except Exception as e:
        print(f"Warning: Failed to clear Triton cache: {e}")

    # Additional garbage collection after clearing caches
    gc.collect()
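
# Example: clear the Triton cache every 10 tuned configs instead of the
# default 50 (the interval value here is illustrative):
#
#   VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL=10 \
#       python benchmarks/kernels/benchmark_moe.py --model <model> --tune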
def ensure_divisibility(numerator, denominator, text):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
        text, numerator, denominator
    )
class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool = False,
    num_iters: int = 100,
    block_quant_shape: list[int] = None,
    use_deep_gemm: bool = False,
) -> float:
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int4_w4a16:
        # Int4 packed weights: 2 int4 values per uint8 byte
        # K dimension is packed (halved)
        intermediate_size = shard_intermediate_size // 2  # after silu_and_mul
        w1 = torch.randint(
            0,
            255,
            (
                num_experts,
                shard_intermediate_size,
                hidden_size // 2,  # int4 packing
            ),
            dtype=torch.uint8,
        )
        w2 = torch.randint(
            0,
            255,
            (
                num_experts,
                hidden_size,
                intermediate_size // 2,  # int4 packing
            ),
            dtype=torch.uint8,
        )
    elif use_int8_w8a16:
        w1 = torch.randint(
            -127,
            127,
            (
                num_experts,
                shard_intermediate_size,
                hidden_size,
            ),
            dtype=torch.int8,
        )
        w2 = torch.randint(
            -127,
            127,
            (
                num_experts,
                hidden_size,
                shard_intermediate_size // 2,
            ),
            dtype=torch.int8,
        )
    else:
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int4_w4a16:
        if block_quant_shape is None:
            raise ValueError("block_quant_shape is required for int4_w4a16")
        group_size = block_quant_shape[1]
        # Scales shape: (E, N, K // group_size) in fp16
        w1_scale = torch.rand(
            (num_experts, shard_intermediate_size, hidden_size // group_size),
            dtype=dtype,
        )
        w2_scale = torch.rand(
            (num_experts, hidden_size, intermediate_size // group_size),
            dtype=dtype,
        )
    elif use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
                * factor_for_scale
            )
        else:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)
            a1_scale = torch.randn(1, dtype=torch.float32)
            a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        from vllm.model_executor.layers.fused_moe import override_config

        if use_fp8_w8a8:
            quant_dtype = torch.float8_e4m3fn
        elif use_int8_w8a16:
            quant_dtype = torch.int8
        else:
            quant_dtype = None

        quant_config = FusedMoEQuantConfig.make(
            quant_dtype=quant_dtype,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_quant_shape,
            weight_dtype="int4" if use_int4_w4a16 else None,
        )

        deep_gemm_experts = None
        if use_deep_gemm:
            moe_config = FusedMoEConfig(
                num_experts=num_experts,
                experts_per_token=topk,
                hidden_dim=hidden_size,
                intermediate_size_per_partition=shard_intermediate_size,
                num_local_experts=num_experts,
                num_logical_experts=num_experts,
                activation=MoEActivation.SILU,
                moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
                in_dtype=init_dtype,
                routing_method=RoutingMethodType.TopK,
                device="cuda",
            )
            deep_gemm_experts = mk.FusedMoEKernel(
                prepare_finalize=maybe_make_prepare_finalize(
                    moe=moe_config,
                    quant_config=quant_config,
                    allow_new_interface=True,
                    use_monolithic=False,
                ),
                fused_experts=TritonOrDeepGemmExperts(
                    moe_config=moe_config,
                    quant_config=quant_config,
                ),
                inplace=not disable_inplace(),
            )

        with override_config(config):
            topk_weights, topk_ids, token_expert_indices = fused_topk(
                x, input_gating, topk, renormalize=not use_deep_gemm
            )
            inplace = not disable_inplace()
            if use_deep_gemm:
                return deep_gemm_experts.apply(
                    x,
                    w1,
                    w2,
                    topk_weights,
                    topk_ids,
                    activation=MoEActivation.SILU,
                    global_num_experts=num_experts,
                    apply_router_weight_on_input=False,
                    expert_map=False,
                )
            return fused_experts(
                x,
                w1,
                w2,
                topk_weights,
                topk_ids,
                inplace=inplace,
                quant_config=quant_config,
            )
# JIT compilation & warmup
run
()
torch
.
cuda
.
synchronize
()
# Capture 10 invocations with CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
for
_
in
range
(
10
):
run
()
torch
.
cuda
.
synchronize
()
# Warmup
for
_
in
range
(
5
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
prepare
(
i
)
torch
.
cuda
.
synchronize
()
start_event
.
record
()
graph
.
replay
()
end_event
.
record
()
end_event
.
synchronize
()
latencies
.
append
(
start_event
.
elapsed_time
(
end_event
))
avg
=
sum
(
latencies
)
/
(
num_iters
*
10
)
*
1000
# us
graph
.
reset
()
return
avg
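
# The divisor (num_iters * 10) reflects that each captured graph replays 10
# invocations of run(); the trailing * 1000 converts the per-invocation time
# from milliseconds (CUDA event resolution) to microseconds.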
def get_rocm_tuning_space(use_fp16):
    block_mn_range = [16, 32, 64, 128, 256]
    block_k_range = [16, 32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [1, 4, 8, 16, 32]
    num_stage_range = [2]
    waves_per_eu_range = [0, 1, 2, 4]
    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
    kpack_range = [1, 2] if use_fp16 else []

    param_ranges = {
        "BLOCK_SIZE_M": block_mn_range,
        "BLOCK_SIZE_N": block_mn_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": group_m_range,
        "num_warps": num_warps_range,
        "num_stages": num_stage_range,
        "waves_per_eu": waves_per_eu_range,
    }
    if use_fp16:
        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
        param_ranges["kpack"] = kpack_range

    return param_ranges
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
    configs: list[BenchmarkConfig] = []

    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        block_m_range = [16, 32, 64, 128, 256]
        block_n_range = [32, 64, 128, 256]
        block_k_range = [64, 128, 256]
        num_warps_range = [4, 8]
        group_m_range = [1, 16, 32, 64]
        num_stage_range = [2, 3, 4, 5]

        param_ranges = {
            "BLOCK_SIZE_M": block_m_range,
            "BLOCK_SIZE_N": block_n_range,
            "BLOCK_SIZE_K": block_k_range,
            "GROUP_SIZE_M": group_m_range,
            "num_warps": num_warps_range,
            "num_stages": num_stage_range,
        }

    keys, values = zip(*param_ranges.items())
    for config_values in product(*values):
        config = dict(zip(keys, config_values))
        configs.append(config)

    # Remove configs that are not compatible with fp8 block quantization
    # BLOCK_SIZE_K must be a multiple of block_k
    # BLOCK_SIZE_N must be a multiple of block_n
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        for config in configs[:]:
            if (
                config["BLOCK_SIZE_K"] % block_k != 0
                or config["BLOCK_SIZE_N"] % block_n != 0
            ):
                configs.remove(config)

    return configs
def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    N1, K1 = shard_intermediate_size, hidden_size
    N2, K2 = hidden_size, shard_intermediate_size // 2

    pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, search_space, is_fp16)
    pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, search_space, is_fp16)
    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
    return search_space
# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    pruned_configs = []
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = False
    if M >= 2048 and N >= 2048:
        large_gemm = True

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # some layouts could not work properly when the number of
        # elements per thread is less than 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        if is_fp16:
            if (
                matrix_instr_nonkdim > BLOCK_SIZE_M
                or matrix_instr_nonkdim > BLOCK_SIZE_N
            ):
                continue
            if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
                continue
            if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
                continue
        # Skip BLOCK_SIZE that is too large compared to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
def merge_unique_dicts(list1, list2):
    result = []
    combined_list = list1.copy()
    combined_list.extend(list2)
    for dictionary in combined_list:
        if dictionary not in result:
            result.append(dictionary)
    return result
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        set_random_seed(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_int4_w4a16: bool = False,
        block_quant_shape: list[int] = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        # local import to allow serialization by ray
        set_random_seed(self.seed)
        dtype_str = _get_config_dtype_str(
            dtype,
            use_int8_w8a16=use_int8_w8a16,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int4_w4a16=use_int4_w4a16,
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_quant_shape[0] if block_quant_shape else None
        block_k = block_quant_shape[1] if block_quant_shape else None
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                block_quant_shape,
            )
        else:
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_int4_w4a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        # local import to allow serialization by ray
        from vllm.platforms import current_platform

        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )

        need_device_guard = False
        if current_platform.is_rocm():
            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            if visible_device != f"{self.device_id}":
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for idx, config in enumerate(tqdm(search_space)):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        use_int4_w4a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config

                # Periodically clear Triton JIT cache to prevent OOM
                # This is especially important for large models with many experts
                if (
                    TRITON_CACHE_CLEAR_INTERVAL > 0
                    and idx > 0
                    and idx % TRITON_CACHE_CLEAR_INTERVAL == 0
                ):
                    clear_triton_cache()

        # Final cleanup after tuning completes
        clear_triton_cache()
        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
        **({"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}),
        **(
            {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
            if "matrix_instr_nonkdim" in config
            else {}
        ),
        **({"kpack": config["kpack"]} if "kpack" in config else {}),
        **({"SPLIT_K": config["SPLIT_K"]} if "SPLIT_K" in config else {}),
    }
def save_configs(
    configs: dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_quant_shape: list[int],
    save_dir: str,
) -> None:
    dtype_str = _get_config_dtype_str(
        dtype,
        use_int8_w8a16=use_int8_w8a16,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int4_w4a16=use_int4_w4a16,
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
    )

    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, filename)
    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump({"triton_version": triton.__version__, **configs}, f, indent=4)
        f.write("\n")
def get_compressed_tensors_block_structure(config, default_value=None):
    config_groups = config.get("config_groups", {})
    if len(config_groups) != 1:
        return default_value

    group = next(iter(config_groups.values()))
    weights = group.get("weights", {})
    block_structure = weights.get("block_structure", default_value)
    return block_structure


def get_weight_block_size_safety(config, default_value=None):
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        if "weight_block_size" in quantization_config:
            return quantization_config["weight_block_size"]
        return get_compressed_tensors_block_structure(
            quantization_config, default_value
        )
    return default_value
def get_model_params(config):
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        hidden_size = config.hidden_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    elif config.architectures[0] in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
        "GlmMoeDsaForCausalLM",
        "Glm4MoeForCausalLM",
        "Glm4MoeLiteForCausalLM",
        "NemotronHForCausalLM",
        "MistralLarge3ForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif config.architectures[0] in (
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
        "Qwen3NextForCausalLM",
    ):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
        text_config = config.get_text_config()
        E = text_config.num_experts
        topk = text_config.num_experts_per_tok
        intermediate_size = text_config.moe_intermediate_size
        hidden_size = text_config.hidden_size
    elif config.architectures[0] == "HunYuanMoEV1ForCausalLM":
        E = config.num_experts
        topk = config.moe_topk[0]
        intermediate_size = config.moe_intermediate_size[0]
        hidden_size = config.hidden_size
    elif config.architectures[0] == "Qwen3OmniMoeForConditionalGeneration":
        E = config.thinker_config.text_config.num_experts
        topk = config.thinker_config.text_config.num_experts_per_tok
        intermediate_size = config.thinker_config.text_config.moe_intermediate_size
        hidden_size = config.thinker_config.text_config.hidden_size
    elif config.architectures[0] == "PixtralForConditionalGeneration":
        # Pixtral can contain different LLM architectures,
        # recurse to get their parameters
        return get_model_params(config.get_text_config())
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    return E, topk, intermediate_size, hidden_size
def get_quantization_group_size(config) -> int | None:
    """Extract the quantization group size from the HF model config.

    This reads directly from the HuggingFace config object (as returned by
    ``get_config()``), not from vLLM's quantization config classes.
    Supports AWQ/GPTQ-style configs (direct 'group_size' key) and
    compressed-tensors configs (nested inside 'config_groups').
    """
    quantization_config = getattr(config, "quantization_config", {})
    if not isinstance(quantization_config, dict):
        return None

    # AWQ / GPTQ style: group_size is a top-level key
    gs = quantization_config.get("group_size")
    if gs is not None:
        return gs

    # compressed-tensors style: group_size is nested in config_groups
    config_groups = quantization_config.get("config_groups", {})
    if not isinstance(config_groups, dict):
        return None
    for group_cfg in config_groups.values():
        if not isinstance(group_cfg, dict):
            continue
        weights = group_cfg.get("weights", {})
        if not isinstance(weights, dict):
            continue
        gs = weights.get("group_size")
        if gs is not None:
            return gs
    return None
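
# For reference, the two quantization_config layouts handled above look
# roughly like this (illustrative, abbreviated JSON):
#
#   AWQ/GPTQ:           {"group_size": 128, ...}
#   compressed-tensors: {"config_groups":
#                           {"group_0": {"weights": {"group_size": 128, ...}}}}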
def main(args: argparse.Namespace):
    print(args)

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        config = getattr(config, args.model_prefix)

    E, topk, intermediate_size, hidden_size = get_model_params(config)

    enable_ep = bool(args.enable_expert_parallel)
    if enable_ep:
        ensure_divisibility(E, args.tp_size, "Number of experts")
        E = E // args.tp_size
        shard_intermediate_size = 2 * intermediate_size
    else:
        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_int4_w4a16 = args.dtype == "int4_w4a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if use_int4_w4a16:
        group_size = get_quantization_group_size(config)
        if group_size is None:
            raise ValueError(
                "Could not determine group_size from model config. "
                "The model's quantization_config must contain a 'group_size' "
                "field (AWQ/GPTQ) or 'config_groups.*.weights.group_size' "
                "(compressed-tensors)."
            )
        # For int4_w4a16, block_shape = [0, group_size]
        # block_shape[0]=0 means no block quantization on N dimension
        block_quant_shape = [0, group_size]

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = args.batch_size

    use_deep_gemm = bool(args.use_deep_gemm)

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        val = os.environ["HIP_VISIBLE_DEVICES"]
        os.environ["ROCR_VISIBLE_DEVICES"] = val
        del os.environ["HIP_VISIBLE_DEVICES"]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        # int4_w4a16 weights are uint8-packed, not fp16; treat like fp8 for
        # search space generation (no matrix_instr_nonkdim/kpack exploration).
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
        # For int4_w4a16, the group_size constraint on BLOCK_SIZE_K does not
        # apply: the gptq_awq kernel handles arbitrary BLOCK_SIZE_K regardless
        # of group_size. Skip block_quant_shape filtering to keep the full
        # search space (e.g. BLOCK_SIZE_K=64 with group_size=128).
        tune_block_quant_shape = None if use_int4_w4a16 else block_quant_shape
        search_space = get_configs_compute_bound(is_fp16, tune_block_quant_shape)
        if use_int4_w4a16:
            # SPLIT_K is a required kernel constexpr for gptq_awq kernel;
            # only SPLIT_K=1 is used at runtime, so fix it during tuning.
            for cfg in search_space:
                cfg["SPLIT_K"] = 1
        print(f"Start tuning over {len(search_space)} configurations...")

        if use_deep_gemm:
            raise ValueError(
                "Tuning with --use-deep-gemm is not supported as it only tunes Triton "
                "kernels. Please remove the flag."
            )

        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    use_int4_w4a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            block_quant_shape,
            args.save_dir,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    use_int4_w4a16,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")
if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
    )
    parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"],
        default="auto",
    )
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument(
        "--save-dir", type=str, default="./", help="Directory to save tuned results"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, nargs="+", required=False)
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--model-prefix", type=str, required=False)
    args = parser.parse_args()

    main(args)
benchmarks/kernels/benchmark_moe_align_block_size.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import itertools

import torch

from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size,
)
from vllm.triton_utils import triton


def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    return torch.stack(
        [
            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
            for _ in range(num_tokens)
        ]
    )
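
# get_topk_ids draws topk *distinct* expert ids per token (randperm, then
# slice), e.g. num_tokens=2, num_experts=4, topk=2 might yield
# tensor([[3, 0], [1, 2]], device='cuda:0', dtype=torch.int32).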
# test configurations
num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512]
topk_range = [1, 2, 8]
ep_size_range = [1, 8]

configs = list(
    itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range)
)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk", "ep_size"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm"],
        line_names=["vLLM"],
        plot_name="moe-align-block-size-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, ep_size, provider):
    """Benchmark function for Triton."""
    block_size = 256
    torch.cuda.manual_seed_all(0)
    topk_ids = get_topk_ids(num_tokens, num_experts, topk)
    e_map = None
    if ep_size != 1:
        local_e = num_experts // ep_size
        e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: moe_align_block_size(
                topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True
            ),
            quantiles=quantiles,
        )
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_experts",
        type=int,
        default=64,
        choices=[8, 16, 32, 64, 128, 256],
    )
    parser.add_argument(
        "--topk",
        type=int,
        default=8,
        choices=[2, 4, 8],
        help="Top-k value for correctness check.",
    )
    args = parser.parse_args()

    benchmark.run(print_data=True, show_plots=True)
benchmarks/kernels/benchmark_moe_defaults.py
0 → 100644
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark comparing old vs new default fused MoE configs.

Runs the triton fused_moe kernel with three configurations for each scenario:
1. Tuned config (from JSON file, if available) — the target to match
2. Old default (the hardcoded defaults before this change)
3. New default (the improved defaults)

Usage:
    python benchmarks/kernels/benchmark_moe_defaults.py

Produces a table showing kernel time (us) and speedup of new vs old defaults.
"""

import torch

from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_experts,
    get_default_config,
    get_moe_configs,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.torch_utils import set_random_seed

FP8_DTYPE = current_platform.fp8_dtype()
def
old_default_config
(
M
,
E
,
N
,
K
,
topk
,
dtype
=
None
,
block_shape
=
None
):
"""The original defaults before https://github.com/vllm-project/vllm/pull/34846,
for comparison."""
if
dtype
==
"fp8_w8a8"
and
block_shape
is
not
None
:
return
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
block_shape
[
0
],
"BLOCK_SIZE_K"
:
block_shape
[
1
],
"GROUP_SIZE_M"
:
32
,
"SPLIT_K"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
if
not
current_platform
.
is_rocm
()
else
2
,
}
elif
M
<=
E
:
return
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"SPLIT_K"
:
1
,
}
else
:
return
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"SPLIT_K"
:
1
,
}
def
benchmark_config
(
config
,
M
,
E
,
N
,
K
,
topk
,
dtype
,
use_fp8
=
False
,
block_shape
=
None
,
num_iters
=
100
,
):
"""Time a single kernel config. Returns kernel time in microseconds."""
init_dtype
=
torch
.
float16
if
use_fp8
else
dtype
a
=
torch
.
randn
(
M
,
K
,
device
=
"cuda"
,
dtype
=
init_dtype
)
/
10
w1
=
torch
.
randn
(
E
,
2
*
N
,
K
,
device
=
"cuda"
,
dtype
=
init_dtype
)
/
10
w2
=
torch
.
randn
(
E
,
K
,
N
,
device
=
"cuda"
,
dtype
=
init_dtype
)
/
10
w1_scale
=
None
w2_scale
=
None
a1_scale
=
None
a2_scale
=
None
if
use_fp8
:
if
block_shape
is
not
None
:
bsn
,
bsk
=
block_shape
n_tiles_w1
=
triton
.
cdiv
(
2
*
N
,
bsn
)
k_tiles_w1
=
triton
.
cdiv
(
K
,
bsk
)
n_tiles_w2
=
triton
.
cdiv
(
K
,
bsn
)
k_tiles_w2
=
triton
.
cdiv
(
N
,
bsk
)
w1_scale
=
torch
.
rand
(
E
,
n_tiles_w1
,
k_tiles_w1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
rand
(
E
,
n_tiles_w2
,
k_tiles_w2
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
else
:
w1_scale
=
torch
.
rand
(
E
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
rand
(
E
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
rand
(
1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
rand
(
1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
# Only weights are stored in fp8; activations stay in bf16/fp16
# and get dynamically quantized inside the kernel.
w1
=
w1
.
to
(
FP8_DTYPE
)
w2
=
w2
.
to
(
FP8_DTYPE
)
quant_config
=
FusedMoEQuantConfig
.
make
(
quant_dtype
=
torch
.
float8_e4m3fn
if
use_fp8
else
None
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
)
gating
=
torch
.
randn
(
M
,
E
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
# Warmup
for
_
in
range
(
20
):
with
override_config
(
config
):
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
gating
,
topk
,
renormalize
=
True
)
fused_experts
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
quant_config
=
quant_config
,
)
torch
.
cuda
.
synchronize
()
# Benchmark
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
.
record
()
for
_
in
range
(
num_iters
):
with
override_config
(
config
):
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
gating
,
topk
,
renormalize
=
True
)
fused_experts
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
quant_config
=
quant_config
,
)
end
.
record
()
torch
.
cuda
.
synchronize
()
return
start
.
elapsed_time
(
end
)
/
num_iters
*
1000
# ms -> us
# Model configurations: (name, E, N, K, topk, dtype_str, use_fp8, block_shape)
# N = moe_intermediate_size // tp_size (the value used in config file lookup)
MODELS
=
[
# --- Few experts ---
(
"Mixtral bf16"
,
8
,
7168
,
4096
,
2
,
None
,
False
,
None
),
(
"Mixtral fp8"
,
8
,
7168
,
4096
,
2
,
"fp8_w8a8"
,
True
,
None
),
# --- Many experts: real model shapes at tp=1 ---
# Qwen2-MoE-57B: E=60, topk=4, N=1408, K=2048
(
"Qwen2-MoE bf16"
,
60
,
1408
,
2048
,
4
,
None
,
False
,
None
),
# DeepSeek-V2: E=64, topk=6, N=1407, K=4096
# (use 1408 to avoid odd alignment; real model is 1407)
(
"DeepSeek-V2 bf16"
,
64
,
1408
,
4096
,
6
,
None
,
False
,
None
),
# OLMoE-7B: E=64, topk=8, N=2048, K=2048
(
"OLMoE bf16"
,
64
,
2048
,
2048
,
8
,
None
,
False
,
None
),
# GLM-4-100B-A10B: E=128, topk=8, N=1408, K=4096
(
"GLM-4-MoE bf16"
,
128
,
1408
,
4096
,
8
,
None
,
False
,
None
),
# Qwen3-30B-A3B: E=128, topk=8, N=768, K=2048
(
"Qwen3-MoE bf16"
,
128
,
768
,
2048
,
8
,
None
,
False
,
None
),
# DeepSeek-V3 / MiMo-V2-Flash: E=256, topk=8, N=2048, K=7168
(
"DeepSeek-V3 bf16"
,
256
,
2048
,
7168
,
8
,
None
,
False
,
None
),
# Qwen3.5-70B-A22B (Qwen3-Next): E=512, topk=10, N=512, K=2048
(
"Qwen3-Next bf16"
,
512
,
512
,
2048
,
10
,
None
,
False
,
None
),
# E=128 N=1856 bf16
(
"E128 N1856 bf16"
,
128
,
1856
,
4096
,
8
,
None
,
False
,
None
),
# E=256 N=512 bf16 (DS-V3 tp=4)
(
"DS-V3 tp4 bf16"
,
256
,
512
,
7168
,
8
,
None
,
False
,
None
),
# E=512 N=512 bf16 (Qwen3-Next tp=1)
(
"Qwen3-Next bf16"
,
512
,
512
,
2048
,
10
,
None
,
False
,
None
),
# E=512 N=256 bf16 (Qwen3-Next tp=2)
(
"Qwen3-Next tp2"
,
512
,
256
,
2048
,
10
,
None
,
False
,
None
),
# --- FP8 block quant (many experts) ---
# DS-V3 tp=4: E=256, N=512, fp8 block
(
"DS-V3 tp4 fp8blk"
,
256
,
512
,
7168
,
8
,
"fp8_w8a8"
,
True
,
[
128
,
128
]),
# DS-V3 tp=8: E=256, N=256, fp8 block
(
"DS-V3 tp8 fp8blk"
,
256
,
256
,
7168
,
8
,
"fp8_w8a8"
,
True
,
[
128
,
128
]),
# Qwen3-Next tp=2 fp8 block
(
"Qwen3-Next tp2 fp8blk"
,
512
,
256
,
2048
,
10
,
"fp8_w8a8"
,
True
,
[
128
,
128
]),
]
BATCH_SIZES
=
[
1
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]
def
main
():
set_random_seed
(
0
)
torch
.
set_default_device
(
"cuda"
)
dtype
=
torch
.
bfloat16
for
name
,
E
,
N
,
K
,
topk
,
dtype_str
,
use_fp8
,
block_shape
in
MODELS
:
print
(
f
"
\n
{
'='
*
90
}
"
)
print
(
f
"
{
name
}
(E=
{
E
}
, N=
{
N
}
, K=
{
K
}
, topk=
{
topk
}
)"
)
print
(
f
"
{
'='
*
90
}
"
)
# Try to load tuned config
block_n
=
block_shape
[
0
]
if
block_shape
else
None
block_k
=
block_shape
[
1
]
if
block_shape
else
None
tuned
=
get_moe_configs
(
E
,
N
,
dtype_str
,
block_n
,
block_k
)
has_tuned
=
tuned
is
not
None
print
(
f
" Tuned config available:
{
has_tuned
}
"
)
hdr
=
(
f
"
{
'Batch'
:
>
6
}
|
{
'Tuned (us)'
:
>
11
}
|
{
'Old (us)'
:
>
11
}
| "
f
"
{
'New (us)'
:
>
11
}
|
{
'New/Old'
:
>
8
}
|
{
'New/Tuned'
:
>
10
}
"
)
print
(
f
"
{
hdr
}
"
)
print
(
f
"
{
'-'
*
len
(
hdr
)
}
"
)
for
M
in
BATCH_SIZES
:
old_cfg
=
old_default_config
(
M
,
E
,
N
,
K
,
topk
,
dtype_str
,
block_shape
)
new_cfg
=
get_default_config
(
M
,
E
,
N
,
K
,
topk
,
dtype_str
,
block_shape
)
if
has_tuned
:
tuned_cfg
=
tuned
[
min
(
tuned
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
t_tuned
=
benchmark_config
(
tuned_cfg
,
M
,
E
,
N
,
K
,
topk
,
dtype
,
use_fp8
=
use_fp8
,
block_shape
=
block_shape
,
)
else
:
t_tuned
=
None
t_old
=
benchmark_config
(
old_cfg
,
M
,
E
,
N
,
K
,
topk
,
dtype
,
use_fp8
=
use_fp8
,
block_shape
=
block_shape
,
)
t_new
=
benchmark_config
(
new_cfg
,
M
,
E
,
N
,
K
,
topk
,
dtype
,
use_fp8
=
use_fp8
,
block_shape
=
block_shape
,
)
ratio_new_old
=
t_new
/
t_old
tuned_str
=
f
"
{
t_tuned
:
11.2
f
}
"
if
t_tuned
else
f
"
{
'N/A'
:
>
11
}
"
ratio_tuned
=
f
"
{
t_new
/
t_tuned
:
10.2
f
}
x"
if
t_tuned
else
f
"
{
'N/A'
:
>
10
}
"
# flag regressions where new default is >5% slower than old
marker
=
" <--"
if
ratio_new_old
>
1.05
else
""
print
(
f
"
{
M
:
>
6
}
|
{
tuned_str
}
|
{
t_old
:
11.2
f
}
|
{
t_new
:
11.2
f
}
"
f
"|
{
ratio_new_old
:
7.2
f
}
x |
{
ratio_tuned
}{
marker
}
"
)
if
__name__
==
"__main__"
:
main
()
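The full table requires a GPU, but the change in defaults can be inspected without timing anything by diffing the config dictionaries directly. A minimal sketch (an editor's addition, not part of the commit; it assumes this module's old_default_config is in scope and calls get_default_config with the same positional arguments used in the loop above):

# Compare old vs new defaults for a Mixtral-like shape at batch size 64,
# bf16 path (dtype_str=None, block_shape=None).
M, E, N, K, topk = 64, 8, 7168, 4096, 2
old = old_default_config(M, E, N, K, topk, None, None)
new = get_default_config(M, E, N, K, topk, None, None)
for key in sorted(set(old) | set(new)):
    print(f"{key}: old={old.get(key)} new={new.get(key)}")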
benchmarks/kernels/benchmark_moe_permute_unpermute.py
0 → 100644
View file @
fbeb8a6f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from typing import Any, TypedDict

import ray
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute,
    moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import set_random_seed

FP8_DTYPE = current_platform.fp8_dtype()


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def benchmark_permute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        qhidden_states = hidden_states

    gating_output = torch.randn(
        num_iters, num_tokens, num_experts, dtype=torch.float32
    )

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        moe_permute(
            qhidden_states,
            a1q_scale=None,
            topk_ids=topk_ids,
            n_expert=num_experts,
            expert_map=None,
        )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
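# Editor's note on the arithmetic above: elapsed_time() returns milliseconds
# per graph replay, and each replay executes the 10 captured invocations, so
# dividing the summed latencies by num_iters * 10 and scaling by 1000 gives
# microseconds per single call (e.g. 100 replays totalling 50 ms ->
# 50_000 us / 1000 invocations = 50 us each).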
def benchmark_unpermute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_fp8_w8a8:
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare():
        (
            permuted_hidden_states,
            _,
            first_token_off,
            inv_perm_idx,
            _,
        ) = moe_permute(
            qhidden_states,
            a1q_scale=None,
            topk_ids=topk_ids,
            n_expert=num_experts,
            expert_map=None,
        )
        # convert to fp16/bf16 as gemm output
        return (
            permuted_hidden_states.to(dtype),
            first_token_off,
            inv_perm_idx,
        )

    def run(input: tuple):
        (permuted_hidden_states, first_token_off, inv_perm_idx) = input
        output = torch.empty_like(hidden_states)
        moe_unpermute(
            output,
            permuted_hidden_states,
            topk_weights,
            inv_perm_idx,
            first_token_off,
        )

    # JIT compilation & warmup
    input = prepare()
    run(input)
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run(input)
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        set_random_seed(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> tuple[float, float]:
        set_random_seed(self.seed)

        permute_time = benchmark_permute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
        )
        unpermute_time = benchmark_unpermute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
        )
        return permute_time, unpermute_time


def get_weight_block_size_safety(config, default_value=None):
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        return quantization_config.get("weight_block_size", default_value)
    return default_value


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
    elif (
        config.architectures[0] == "DeepseekV3ForCausalLM"
        or config.architectures[0] == "DeepseekV2ForCausalLM"
        or config.architectures[0] == "Glm4MoeForCausalLM"
        or config.architectures[0] == "Glm4MoeLiteForCausalLM"
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
        E = config.num_experts
        topk = config.num_experts_per_tok
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128,
            256, 512, 1024, 1536, 2048, 3072, 4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    outputs = _distribute(
        "benchmark",
        [
            (
                batch_size,
                E,
                hidden_size,
                topk,
                dtype,
                use_fp8_w8a8,
                use_int8_w8a16,
            )
            for batch_size in batch_sizes
        ],
    )

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    main(args)
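For context on the two kernels timed above: moe_permute replicates each token's activation once per selected expert and sorts the copies so that each expert's rows are contiguous (ready for grouped GEMM), while moe_unpermute inverts that ordering and combines the top-k copies with the router weights. The round trip can be sketched in plain PyTorch; the snippet below is an editor's illustration of the idea, not vLLM's kernels or their exact signatures.

import torch

num_tokens, hidden, topk, num_experts = 4, 8, 2, 4
x = torch.randn(num_tokens, hidden)
topk_ids = torch.randint(0, num_experts, (num_tokens, topk))
topk_weights = torch.softmax(torch.randn(num_tokens, topk), dim=-1)

# permute: replicate each token topk times, then sort the copies by expert id
src_rows = torch.arange(num_tokens).repeat_interleave(topk)
order = topk_ids.flatten().argsort(stable=True)
permuted = x[src_rows[order]]  # expert-contiguous activations
inv_perm = order.argsort()     # inverse permutation, undoes the sort

# ... per-expert GEMMs would run on `permuted` here (identity in this sketch) ...

# unpermute: restore token order, then combine the top-k copies with weights
restored = permuted[inv_perm].view(num_tokens, topk, hidden)
out = (restored * topk_weights.unsqueeze(-1)).sum(dim=1)
# with identity "experts", the weighted combine must reproduce x exactly
assert torch.allclose(out, x, atol=1e-6)

With identity "experts" between the two steps the weighted combine reproduces the input, which the final assert checks. These are pure data-movement ops, cheap enough that launch overhead matters, which is why the benchmarks above capture 10 invocations in a CUDA graph rather than timing individual launches.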