Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tsoc
vllm-auto-test
Commits
d1a06223
Commit
d1a06223
authored
Feb 24, 2026
by
liuxu3
Browse files
added vllm092 auto test scripts
parent
fba2e3b5
Changes
162
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6521 additions
and
0 deletions
+6521
-0
online_apiserver_test_maxbs/benchmarks/disagg_benchmarks/visualize_benchmark_results.py
...nchmarks/disagg_benchmarks/visualize_benchmark_results.py
+47
-0
online_apiserver_test_maxbs/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
...axbs/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
+228
-0
online_apiserver_test_maxbs/benchmarks/kernels/bench_fp8_gemm.py
...apiserver_test_maxbs/benchmarks/kernels/bench_fp8_gemm.py
+159
-0
online_apiserver_test_maxbs/benchmarks/kernels/bench_int8_gemm.py
...piserver_test_maxbs/benchmarks/kernels/bench_int8_gemm.py
+169
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_aqlm.py
...apiserver_test_maxbs/benchmarks/kernels/benchmark_aqlm.py
+345
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_bitblas.py
...server_test_maxbs/benchmarks/kernels/benchmark_bitblas.py
+242
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_cutlass_fp4_moe.py
...est_maxbs/benchmarks/kernels/benchmark_cutlass_fp4_moe.py
+490
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
...axbs/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+383
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_layernorm.py
...rver_test_maxbs/benchmarks/kernels/benchmark_layernorm.py
+93
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_lora.py
...apiserver_test_maxbs/benchmarks/kernels/benchmark_lora.py
+1065
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_machete.py
...server_test_maxbs/benchmarks/kernels/benchmark_machete.py
+732
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_marlin.py
...iserver_test_maxbs/benchmarks/kernels/benchmark_marlin.py
+413
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_moe.py
..._apiserver_test_maxbs/benchmarks/kernels/benchmark_moe.py
+737
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_moe_align_block_size.py
...axbs/benchmarks/kernels/benchmark_moe_align_block_size.py
+159
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_moe_permute_unpermute.py
...xbs/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+417
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_paged_attention.py
...est_maxbs/benchmarks/kernels/benchmark_paged_attention.py
+251
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_quant.py
...piserver_test_maxbs/benchmarks/kernels/benchmark_quant.py
+108
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_rmsnorm.py
...server_test_maxbs/benchmarks/kernels/benchmark_rmsnorm.py
+256
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_rope.py
...apiserver_test_maxbs/benchmarks/kernels/benchmark_rope.py
+133
-0
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_shapes.py
...iserver_test_maxbs/benchmarks/kernels/benchmark_shapes.py
+94
-0
No files found.
Too many changes to show.
To preserve performance only
162 of 162+
files are displayed.
Plain diff
Email patch
online_apiserver_test_maxbs/benchmarks/disagg_benchmarks/visualize_benchmark_results.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
matplotlib.pyplot
as
plt
import
pandas
as
pd
if __name__ == "__main__":
    # Gather the per-run JSON summaries emitted by the disaggregated-prefill
    # benchmark sweep into a single table, then plot TTFT/ITL metrics vs QPS.
    data = []
    for name in ["disagg_prefill", "chunked_prefill"]:
        for qps in [2, 4, 6, 8]:
            with open(f"results/{name}-qps-{qps}.json") as f:
                record = json.load(f)
                # Tag each record with its serving mode and request rate so
                # the rows can be filtered after concatenation.
                record["name"] = name
                record["qps"] = qps
                data.append(record)

    df = pd.DataFrame.from_dict(data)
    dis_df = df[df["name"] == "disagg_prefill"]
    chu_df = df[df["name"] == "chunked_prefill"]

    plt.style.use("bmh")
    plt.rcParams["font.size"] = 20

    # One figure per latency metric, both serving modes overlaid.
    metrics = [
        "mean_ttft_ms",
        "median_ttft_ms",
        "p99_ttft_ms",
        "mean_itl_ms",
        "median_itl_ms",
        "p99_itl_ms",
    ]
    for key in metrics:
        fig, ax = plt.subplots(figsize=(11, 7))
        plt.plot(
            dis_df["qps"],
            dis_df[key],
            label="disagg_prefill",
            marker="o",
            linewidth=4,
        )
        plt.plot(
            chu_df["qps"],
            chu_df[key],
            label="chunked_prefill",
            marker="o",
            linewidth=4,
        )
        ax.legend()
        ax.set_xlabel("QPS")
        ax.set_ylabel(key)
        ax.set_ylim(bottom=0)
        fig.savefig(f"results/{key}.png")
        plt.close(fig)
online_apiserver_test_maxbs/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pickle
as
pkl
import
time
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
itertools
import
product
from
typing
import
Callable
,
Optional
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
tqdm
import
tqdm
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.layernorm
import
RMSNorm
@dataclass
class bench_params_t:
    """One benchmark configuration for the RMSNorm + quant kernels."""

    num_tokens: int
    hidden_size: int
    add_residual: bool
    dtype: torch.dtype

    def description(self):
        """Human-readable label used as the benchmark sub-label."""
        return (
            f"N {self.num_tokens} "
            f"x D {self.hidden_size} "
            f"x R {self.add_residual} "
            f"x DT {self.dtype}"
        )


def get_bench_params() -> list[bench_params_t]:
    """Enumerate the full cross product of benchmark configurations."""
    ## Test Fixtures
    token_counts = [2**exp for exp in range(11)]
    hidden_sizes = list(range(1024, 8129, 1024))
    residual_options = [True, False]
    dtypes = [torch.bfloat16, torch.float]

    return [
        bench_params_t(tokens, hidden, residual, dt)
        for tokens, hidden, residual, dt in product(
            token_counts, hidden_sizes, residual_options, dtypes
        )
    ]
# Reference impls
def unfused_int8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    """Reference path: RMSNorm followed by a separate dynamic int8 quant."""
    # Norm.  forward_cuda returns a bare tensor when no residual is passed
    # and a (normed, residual) tuple when one is.
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant in a second kernel launch (the overhead being measured).
    torch_out, _, _ = ops.scaled_int8_quant(torch_out)


def unfused_fp8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    """Reference path: RMSNorm followed by a separate dynamic fp8 quant."""
    # Norm (see unfused_int8_impl for the return-shape split).
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant in a second kernel launch.
    torch_out, _ = ops.scaled_fp8_quant(torch_out)


def fused_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    """Fused path: norm + dynamic per-token quant in a single kernel."""
    out, _ = ops.rms_norm_dynamic_per_token_quant(
        x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
    )
# Bench functions
def bench_fn(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor,
    quant_dtype: torch.dtype,
    label: str,
    sub_label: str,
    fn: Callable,
    description: str,
) -> TMeasurement:
    """Time ``fn(rms_norm_layer, x, residual, quant_dtype)`` with the
    torch benchmark harness and return the resulting measurement."""
    min_run_time = 1

    # Renamed from `globals` so the builtin globals() is not shadowed.
    bench_globals = {
        "rms_norm_layer": rms_norm_layer,
        "x": x,
        "residual": residual,
        "quant_dtype": quant_dtype,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
        globals=bench_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
    """Benchmark all four norm+quant variants for one configuration and
    return their measurements (also printed as a comparison table)."""
    # Build the layer under test.
    layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)

    # Scale inputs down so values stay in a reasonable quantization range.
    scale = 1 / params.hidden_size
    x = (
        torch.randn(
            params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
        )
        * scale
    )
    residual = (
        (torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
    )

    # (quant dtype, implementation, description) — order matches the
    # original unfused-int8 / unfused-fp8 / fused-int8 / fused-fp8 sweep.
    variants = [
        (torch.int8, unfused_int8_impl, "unfused_int8_impl"),
        (torch.float8_e4m3fn, unfused_fp8_impl, "unfused_fp8_impl"),
        (torch.int8, fused_impl, "fused_int8_impl"),
        (torch.float8_e4m3fn, fused_impl, "fused_fp8_impl"),
    ]
    timers = [
        bench_fn(layer, x, residual, quant_dtype, label, sub_label, impl, desc)
        for quant_dtype, impl, desc in variants
    ]

    print_timers(timers)
    return timers
# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Print a comparison table for a collection of measurements."""
    TBenchmark.Compare(timers).print()


def main():
    """Run the full RMSNorm + dynamic-per-token-quant sweep on CUDA and
    pickle the raw measurements for later inspection."""
    torch.set_default_device("cuda")

    timers = []
    for bp in tqdm(get_bench_params()):
        timers.extend(
            bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())
        )
    print_timers(timers)

    # pickle all the results
    timestamp = int(time.time())
    with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
        pkl.dump(timers, f)


if __name__ == "__main__":
    main()
online_apiserver_test_maxbs/benchmarks/kernels/bench_fp8_gemm.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
torch
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm._custom_ops
import
cutlass_scaled_mm
as
vllm_scaled_mm
from
vllm._custom_ops
import
scaled_fp8_quant
as
vllm_scaled_fp8_quant
from
vllm.triton_utils
import
triton
PROVIDER_CFGS
=
{
"torch-bf16"
:
dict
(
enabled
=
True
),
"fp8-tensor-w-token-a"
:
dict
(
w
=
"tensor"
,
a
=
"token"
,
no_a_quant
=
False
,
enabled
=
False
),
"fp8-tensor-w-tensor-a"
:
dict
(
w
=
"tensor"
,
a
=
"tensor"
,
no_a_quant
=
False
,
enabled
=
True
),
"fp8-channel-w-token-a"
:
dict
(
w
=
"channel"
,
a
=
"token"
,
no_a_quant
=
False
,
enabled
=
True
),
"fp8-channel-w-tensor-a"
:
dict
(
w
=
"channel"
,
a
=
"tensor"
,
no_a_quant
=
False
,
enabled
=
False
),
"fp8-tensor-w-token-a-noquant"
:
dict
(
w
=
"tensor"
,
a
=
"token"
,
no_a_quant
=
True
,
enabled
=
False
),
"fp8-tensor-w-tensor-a-noquant"
:
dict
(
w
=
"tensor"
,
a
=
"tensor"
,
no_a_quant
=
True
,
enabled
=
True
),
"fp8-channel-w-token-a-noquant"
:
dict
(
w
=
"channel"
,
a
=
"token"
,
no_a_quant
=
True
,
enabled
=
True
),
"fp8-channel-w-tensor-a-noquant"
:
dict
(
w
=
"channel"
,
a
=
"tensor"
,
no_a_quant
=
True
,
enabled
=
False
),
}
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
    """Quantize weight matrix ``b`` to fp8 and return (b_fp8.T, scale).

    ``w_type`` selects per-tensor (static unit scale) or per-channel
    (dynamic per-token) quantization.
    """
    if w_type == "tensor":
        # Static per-tensor scale of 1.0.
        scale_b = torch.ones(1, device=device, dtype=torch.float32)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
        return b_fp8.t(), scale_b_fp8

    # Per-channel: let the kernel compute one scale per row.
    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
    return b_fp8.t(), scale_b_fp8
def build_fp8_runner(cfg, a, b, dtype, device):
    """Build a zero-arg callable performing one fp8 GEMM for provider ``cfg``.

    The weight is always quantized up-front; the activation is quantized
    either once here (``no_a_quant``) or on every call.
    """
    b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)

    per_tensor_act = cfg["a"] == "tensor"
    scale_a_const = (
        torch.ones(1, device=device, dtype=torch.float32) if per_tensor_act else None
    )

    def _quant_act():
        # Per-tensor uses the constant scale; per-token is dynamic.
        if per_tensor_act:
            return vllm_scaled_fp8_quant(a, scale_a_const)
        return vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)

    if cfg["no_a_quant"]:
        # Quantize the activation once, outside the timed region.
        a_fp8, scale_a_fp8 = _quant_act()

        def run():
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        return run

    # Quantize the activation inside the timed region on every call.
    def run():
        a_fp8, scale_a_fp8 = _quant_act()
        return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    return run
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs FP8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Benchmark one (M, N, K) GEMM for the given provider.

    Returns (median, max, min) throughput in TFLOP/s; the quantile order
    is swapped because min time corresponds to max throughput.
    """
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_fp8_runner(cfg, a, b, dtype, device)
        # Pass the runner directly instead of wrapping it in a lambda.
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            run_quant, quantiles=quantiles
        )

    # A GEMM performs 2*M*N*K FLOPs; convert milliseconds to TFLOP/s.
    # (def instead of a name-bound lambda, per PEP 8 / E731.)
    def to_tflops(t_ms):
        return (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)

    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
    """Expand the requested models/TP sizes into [K, N, model] entries.

    Each weight shape from WEIGHT_SHAPES has its tensor-parallel dimension
    divided by the TP size; the model name is appended to each entry.
    """
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        # Deep-copy so shrinking the TP dimension never mutates the
        # shared WEIGHT_SHAPES table.
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    # Run the sweep once per (K, N) weight shape; results are saved per shape.
    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_fp8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
online_apiserver_test_maxbs/benchmarks/kernels/bench_int8_gemm.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
torch
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm._custom_ops
import
cutlass_scaled_mm
as
vllm_scaled_mm
from
vllm._custom_ops
import
scaled_int8_quant
as
vllm_scaled_int8_quant
from
vllm.triton_utils
import
triton
# Provider matrix for the BF16-vs-INT8 GEMM sweep.
#   w / a:      quantization granularity for weights / activations
#   no_a_quant: activation quantized once up-front instead of per call
#   enabled:    include this provider in the benchmark run
PROVIDER_CFGS = {
    "torch-bf16": {"enabled": True},
    "int8-tensor-w-token-a": {
        "w": "tensor", "a": "token", "no_a_quant": False, "enabled": False,
    },
    "int8-tensor-w-tensor-a": {
        "w": "tensor", "a": "tensor", "no_a_quant": False, "enabled": True,
    },
    "int8-channel-w-token-a": {
        "w": "channel", "a": "token", "no_a_quant": False, "enabled": True,
    },
    "int8-channel-w-tensor-a": {
        "w": "channel", "a": "tensor", "no_a_quant": False, "enabled": False,
    },
    "int8-tensor-w-token-a-noquant": {
        "w": "tensor", "a": "token", "no_a_quant": True, "enabled": False,
    },
    "int8-tensor-w-tensor-a-noquant": {
        "w": "tensor", "a": "tensor", "no_a_quant": True, "enabled": True,
    },
    "int8-channel-w-token-a-noquant": {
        "w": "channel", "a": "token", "no_a_quant": True, "enabled": True,
    },
    "int8-channel-w-tensor-a-noquant": {
        "w": "channel", "a": "tensor", "no_a_quant": True, "enabled": False,
    },
}
def _quant_weight(b, w_type, device):
    """Quantize weight matrix ``b`` to int8 and return (b_int8.T, scale)."""
    if w_type == "tensor":
        # Static per-tensor scale of 1.0 — exactly one scale element.
        scale_b = torch.ones(1, device=device, dtype=torch.float32)
        b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
        assert scale_b_int8.numel() == 1
        return b_int8.t(), scale_b_int8

    # channel
    # Dynamic per-channel quantization — one scale per output row.
    b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
    assert scale_b_int8.numel() == b.shape[0]
    return b_int8.t(), scale_b_int8
def build_int8_runner(cfg, a, b, dtype, device):
    """Build a zero-arg callable performing one int8 GEMM for provider ``cfg``."""
    # quant before running the kernel
    b_int8, scale_b_int8 = _quant_weight(b, cfg["w"], device)

    per_tensor_act = cfg["a"] == "tensor"
    scale_a_const = (
        torch.ones(1, device=device, dtype=torch.float32) if per_tensor_act else None
    )

    def _quant_act():
        # Per-tensor uses the constant scale; per-token is dynamic.
        if per_tensor_act:
            return vllm_scaled_int8_quant(a, scale_a_const)
        return vllm_scaled_int8_quant(a)

    # no quant, create activation ahead
    if cfg["no_a_quant"]:
        a_int8, scale_a_int8, _ = _quant_act()

        def run_quant():
            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)

        return run_quant

    # dynamic quant, create activation inside
    def run_quant():
        a_int8, scale_a_int8, _ = _quant_act()
        return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)

    return run_quant
# Providers that actually run / appear in plots (missing "enabled" counts
# as disabled, hence .get).
_enabled = [name for name, cfg in PROVIDER_CFGS.items() if cfg.get("enabled")]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        # Simplified from the redundant `[k for k in _enabled]` copy-comprehension.
        line_names=list(_enabled),
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs INT8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Benchmark one (M, N, K) GEMM for the given provider.

    Returns (median, max, min) throughput in TFLOP/s; the quantile order
    is swapped because min time corresponds to max throughput.
    """
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_int8_runner(cfg, a, b, dtype, device)
        # Pass the runner directly instead of wrapping it in a lambda.
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            run_quant, quantiles=quantiles
        )

    # A GEMM performs 2*M*N*K FLOPs; convert milliseconds to TFLOP/s.
    # (def instead of a name-bound lambda, per PEP 8 / E731.)
    def to_tflops(t_ms):
        return (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)

    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
    """Expand the requested models/TP sizes into [K, N, model] entries."""
    KN_model_names = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        # Deep-copy so shrinking the TP dimension never mutates the
        # shared WEIGHT_SHAPES table.
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    # Run the sweep once per (K, N) weight shape; results are saved per shape.
    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_int8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_aqlm.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
sys
from
typing
import
Optional
import
torch
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.aqlm
import
(
dequantize_weight
,
generic_dequantize_gemm
,
get_int_dtype
,
optimized_dequantize_gemm
,
)
from
vllm.utils
import
FlexibleArgumentParser
# Pin the benchmark to GPU 0 so timings are not affected by device choice.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def torch_mult(
    # [..., in_features]
    input: torch.Tensor,
    weights: torch.Tensor,
    # [num_out_groups, 1, 1, 1]
    scales: torch.Tensor,
) -> torch.Tensor:
    """Plain dense matmul baseline.

    ``scales`` is accepted (but unused) so the call shape matches the
    dequant variants timed alongside it.
    """
    return F.linear(input, weights)
def dequant_out_scale(
    # [..., in_features]
    input: torch.Tensor,
    # [num_out_groups, num_in_groups, num_codebooks]
    codes: torch.IntTensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    # [num_out_groups, 1, 1, 1]
    scales: torch.Tensor,
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """AQLM dequant-then-GEMM, applying the per-output scales to the output."""
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        # Scale the GEMM output: flatten to 2-D, broadcast one scale per
        # output column, then restore the original shape.
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)

    # With a bias the scales must be folded into the weights instead,
    # so the bias is added after scaling.
    b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)
def dequant_weight_scale(
    # [..., in_features]
    input: torch.Tensor,
    # [num_out_groups, num_in_groups, num_codebooks]
    codes: torch.IntTensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    # [num_out_groups, 1, 1, 1]
    scales: torch.Tensor,
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """AQLM dequant-then-GEMM, folding the scales into the weights."""
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    # Broadcast one scale per output row across the in_features dimension.
    b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)
def dequant_no_scale(
    # [..., in_features]
    input: torch.Tensor,
    # [num_out_groups, num_in_groups, num_codebooks]
    codes: torch.IntTensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    # [num_out_groups, 1, 1, 1]
    scales: torch.Tensor,
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """AQLM dequant-then-GEMM with the scales deliberately skipped
    (isolates the cost of the scale application)."""
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
    return F.linear(input, weights, bias)
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
    """Print dequantized weights from the reference and CUDA kernels side by
    side for eyeball comparison (no assertion — visual check only)."""
    n = int(parts.sum().item())

    device = torch.device("cuda:0")

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(
        -code_range,
        code_range,
        size=(n, k // ingroups, nbooks),
        dtype=get_int_dtype(bits),
        device=device,
    )

    codebooks = torch.randn(
        size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
        dtype=torch.float16,
        device=device,
    )

    # Overwrite the first 16 codebook rows with recognizable integer
    # patterns (scaled by 10**book) so mismatches are visible in the dump.
    count = 0
    for row in range(16):
        for col in range(8):
            for book in range(nbooks):
                codebooks[book, row, 0, col] = count * (10**book)
                count += 1

    print("codes shape", codes.shape)

    # Plant known codes at both ends of the first output row.
    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])
def main():
    """Parse CLI arguments and either run the visual dequant test or sweep
    the AQLM GEMM benchmark grid, writing a CSV of per-method timings."""
    parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument(
        "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
    )
    parser.add_argument(
        "--bits",
        type=int,
        default=16,
        help="Number of bits per code element (default: 16)",
    )
    parser.add_argument(
        "--test",
        type=bool,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)",
    )

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        sys.stdout = f
        # try/finally so stdout is restored even if a benchmark raises
        # (the original left sys.stdout pointing at the closed file).
        try:
            print("m | k | n | n parts", end="")
            for method in methods:
                print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
            print("")

            # These are reasonable prefill sizes.
            ksandpartions = (
                (4096, (4096, 4096, 4096)),
                (4096, (4096,)),
                (4096, (11008, 11008)),
                (11008, (4096,)),
            )

            # reasonable ranges for m.
            for m in [
                1,
                2,
                4,
                8,
                10,
                12,
                14,
                16,
                24,
                32,
                48,
                52,
                56,
                64,
                96,
                112,
                128,
                256,
                512,
                1024,
                1536,
                2048,
                3072,
                4096,
            ]:
                # Progress indicator on the real terminal, not the CSV.
                print(f"{m}", file=sys.__stdout__)
                for ksp in ksandpartions:
                    run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
        finally:
            sys.stdout = sys.__stdout__
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
    """Time every method at one (m, k, parts) point and print a CSV row."""
    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f"{m} | {k} | {n} | {parts.tolist()}", end="")

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        # Report the best trial.  The original printed the last trial's
        # kernel_dur_us, leaving best_time_us a dead store; the two are
        # identical while num_trials == 1.
        print(f" | {best_time_us:.0f}", end="")

    print("")
def run_timing(
    num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
) -> float:
    """Time `num_calls` invocations of `method` with CUDA events.

    Returns the average duration per call in milliseconds. `method` is either
    `torch_mult` (plain fp16 matmul baseline) or an AQLM-style kernel taking
    (input, codes, codebooks, scales, parts, bias).
    """
    # Total output width is the sum over all output partitions.
    n = int(parts.sum().item())

    device = torch.device("cuda:0")

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    # Codes are signed, so the usable range is half of 2**bits.
    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(
        -code_range,
        code_range,
        size=(n, k // ingroups, nbooks),
        dtype=get_int_dtype(bits),
        device=device,
    )

    # One group of `nbooks` codebooks per output partition, each with
    # 2**bits entries of 1x8 fp16 vectors.
    codebooks = torch.randn(
        size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
        dtype=torch.float16,
        device=device,
    )

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

    # CUDA events measure GPU time, not host time.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    if method is torch_mult:
        # Baseline ignores codes/codebooks and multiplies dense weights.
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    # Block until the GPU has finished all timed work before reading the timer.
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms
# Script entry point: propagate main()'s return value to the shell exit code.
if __name__ == "__main__":
    sys.exit(main())
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_bitblas.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from
vllm.model_executor.layers.quantization.utils.bitblas_utils
import
(
MINIMUM_BITBLAS_VERSION
,
)
# Verify that a sufficiently new bitblas is importable before doing anything
# else, failing with an actionable install hint otherwise.
# NOTE(review): `__version__ < MINIMUM_BITBLAS_VERSION` compares version
# strings lexicographically (e.g. "0.10" < "0.9"), which can misfire for
# multi-digit components — confirm against how the constant is formatted.
try:
    import bitblas

    if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
        raise ImportError(
            "bitblas version is wrong. Please "
            f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
        )
except ImportError as e:
    bitblas_import_exception = e
    raise ValueError(
        # Message fix: the original concatenation lacked a space, yielding
        # "could not importwith the following error".
        "Trying to use the bitblas backend, but could not import "
        f"with the following error: {bitblas_import_exception}. "
        "Please install bitblas through the following command: "
        f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
    ) from bitblas_import_exception
from
bitblas_import_exception
from
bitblas
import
Matmul
,
MatmulConfig
,
auto_detect_nvidia_target
from
vllm.utils
import
FlexibleArgumentParser
# Command-line interface for the BitBLAS int4 GEMM benchmark.
parser = FlexibleArgumentParser(
    description="Benchmark BitBLAS int4 on a specific target."
)

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    # Probes the locally installed NVIDIA GPU when no target is given.
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking.",
)
parser.add_argument(
    "--group_size", type=int, default=None, help="Group size for grouped quantization."
)
parser.add_argument(
    "--A_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "float64", "int32", "int8"],
    help="Data type of activation A.",
)
parser.add_argument(
    "--W_dtype",
    type=str,
    default="int4",
    choices=[
        "float16",
        "float32",
        "float64",
        "int32",
        "int8",
        "int4",
        "int2",
        "int1",
        "nf4",
        "fp4_e2m1",
    ],
    help="Data type of weight W.",
)
parser.add_argument(
    "--accum_dtype",
    type=str,
    default="float16",
    choices=["float16", "int32"],
    help="Data type for accumulation.",
)
parser.add_argument(
    "--out_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "int32", "int8"],
    help="Data type for output.",
)
parser.add_argument(
    "--layout",
    type=str,
    default="nt",
    choices=["nt", "nn"],
    help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
parser.add_argument(
    "--with_bias", action="store_true", help="Include bias in the benchmark."
)
parser.add_argument(
    "--with_scaling",
    action="store_true",
    help="Include scaling factor in the quantization.",
)
parser.add_argument(
    "--with_zeros", action="store_true", help="Include zeros in the quantization."
)
parser.add_argument(
    "--zeros_mode",
    type=str,
    default=None,
    choices=["original", "rescale", "quantized"],
    help="Specify the mode for calculating zeros.",
)

# Parse the arguments
args = parser.parse_args()
# Assign arguments to variables
target = args.target
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
group_size = args.group_size
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode

# Define a list of shared arguments that repeat in every config.
# Order matters: it must match MatmulConfig's positional parameters after
# the (M, K, N) shape.
shared_args = [
    A_dtype,
    W_dtype,
    out_dtype,
    accum_dtype,
    layout,
    with_bias,
    group_size,
    with_scaling,
    with_zeros,
    zeros_mode,
]
# Define just the (M, K, N) shapes in a more compact list.
# The first half (M=1) covers decode-style GEMVs; the second half (large M)
# covers prefill-style GEMMs for the same model weight shapes.
shapes = [
    # square test
    (1, 16384, 16384),
    # BLOOM-176B
    (1, 43008, 14336),
    (1, 14336, 14336),
    (1, 57344, 14336),
    (1, 14336, 57344),
    # OPT-65B
    (1, 9216, 9216),
    (1, 36864, 9216),
    (1, 9216, 36864),
    (1, 22016, 8192),
    # LLAMA-70B/65B
    (1, 8192, 22016),
    (1, 8192, 8192),
    (1, 28672, 8192),
    (1, 8192, 28672),
    # square test
    (16384, 16384, 16384),
    # BLOOM-176B
    (8192, 43008, 14336),
    (8192, 14336, 14336),
    (8192, 57344, 14336),
    (8192, 14336, 57344),
    # OPT-65B
    (8192, 9216, 9216),
    (8192, 36864, 9216),
    (8192, 9216, 36864),
    (8192, 22016, 8192),
    # LLAMA-70B/65B
    (8192, 8192, 22016),
    (8192, 8192, 8192),
    (8192, 28672, 8192),
    (8192, 8192, 28672),
]

# Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]
# Every entry pairs (config class, operator class, full argument tuple).
benchmark_sets = []
benchmark_sets.extend(test_shapes)

benchmark_results = {}
for config_class, operator, input_args in benchmark_sets:
    config = config_class(*input_args)
    # enable_tuning=True lets BitBLAS auto-tune the kernel for this config.
    matmul = operator(config, target=target, enable_tuning=True)
    kernel_latency = matmul.profile_latency()

    print("Time cost is: {:.3f} ms".format(kernel_latency))

    # Key format is "<OperatorName>-<arg1>-<arg2>-..."; the table code below
    # recovers the operator name with split("-").
    profile_config = {
        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
            "BitBLAS_top20_latency": kernel_latency,
        }
    }
    benchmark_results.update(profile_config)
# Define headers for the table
headers = [
    "PrimFunc",
    "Input Arguments",
    "BitBLAS Top20 Latency",
]

# Calculate column widths for pretty printing.
# Each column is wide enough for its widest cell or its header, plus 2 spaces.
col_widths = [0, 0, 0]
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
    col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
    col_widths[2] = max(
        col_widths[2],
        len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
        len(headers[2]) + 2,
    )
    # break only if you want to measure widths from a single example;
    # otherwise, let it loop over all items.

# Print header (cells are pre-padded with ljust, so join with "").
for i, header in enumerate(headers):
    headers[i] = header.ljust(col_widths[i])
print("".join(headers))
print("-" * sum(col_widths))

# Print rows
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    row = [
        func_name,
        input_args_str,
        f"{values['BitBLAS_top20_latency']:.3f} ms",
    ]
    row_str = "".join(
        [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
    )
    print(row_str)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_cutlass_fp4_moe.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8)
and 16-bit activations.
"""
import
nvtx
import
torch
import
torch.utils.benchmark
as
benchmark
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
# Per-model MoE layer shapes: [num_experts, topk, intermediate (k), hidden (n)].
WEIGHT_SHAPES_MOE = {
    "nvidia/DeepSeek-R1-FP4": [
        [256, 8, 2048, 7168],
    ],
}

DEFAULT_MODELS = [
    "nvidia/DeepSeek-R1-FP4",
]

DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
DEFAULT_TP_SIZES = [1]

# Only tensor-wise scaling variants are exercised by default.
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]

# Largest representable magnitudes of the two quantized formats, used to
# derive per-expert global scales.
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    """Benchmark one MoE configuration four ways and append the measurements.

    Times (1) triton fused_experts with fp8 tensor-scaled weights, (2) the
    same call replayed through a CUDA graph, (3) the cutlass_moe_fp4 kernel
    with nvfp4 block-scaled weights, and (4) its CUDA-graph replay. Each
    measurement is appended to `results`.
    """
    label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )
    print(f"Testing: {sub_label}")
    (m, k, n) = mkn
    dtype = torch.half
    device = "cuda"

    # Random activations/weights, scaled down to keep magnitudes small.
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    _, a_fp8_scale = ops.scaled_fp8_quant(a)

    # Per-expert fp8 quantization of both weight matrices (tensor-wise scales).
    w1_fp8q = torch.empty(
        (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
    )
    w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
    w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
    w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])

    # Untransposed copies feed the triton path; the transposed views are
    # created here but the triton/cutlass calls below only use the notransp
    # and fp4 tensors.
    w1_fp8q_notransp = w1_fp8q.clone()
    w2_fp8q_notransp = w2_fp8q.clone()
    w1_fp8q = w1_fp8q.transpose(1, 2)
    w2_fp8q = w2_fp8q.transpose(1, 2)

    score = torch.randn((m, num_experts), device=device, dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    # nvfp4 block-scaled buffers: one fp8 scale per 16-element block.
    quant_blocksize = 16
    w1_blockscale = torch.empty(
        (num_experts, 2 * n, k // quant_blocksize),
        device=device,
        dtype=torch.float8_e4m3fn,
    )
    w2_blockscale = torch.empty(
        (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
    )

    # n_b_scales = 2 * n if per_out_ch else 1
    # k_b_scales = k if per_out_ch else 1
    # Packed fp4 weights: two 4-bit values per uint8 byte (hence // 2).
    w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
    w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)

    # Per-expert global scales; activation global scales are fixed at 1.
    w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
    w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
    a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
    a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_e = w1[expert]
        w2_e = w2[expert]
        w1_amax = torch.abs(w1_e).max().to(torch.float32)
        w2_amax = torch.abs(w2_e).max().to(torch.float32)
        # Global scale maps the per-expert absmax onto the fp8*fp4 range.
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
            w1_e, w1_gs[expert]
        )
        w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
            w2_e, w2_gs[expert]
        )

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_fp8_scale: torch.Tensor,
        num_repeats: int,
    ):
        # Eager triton fp8 path, repeated num_repeats times per timed call.
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_fp8_scale,
            )

    def run_cutlass_moe_fp4(
        a: torch.Tensor,
        w1_fp4: torch.Tensor,
        w2_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w2_blockscale: torch.Tensor,
        w1_gs: torch.Tensor,
        w2_gs: torch.Tensor,
        a1_gs: torch.Tensor,
        a2_gs: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
        num_repeats: int,
    ):
        # NOTE(review): parameter `e` is unused — the call passes the
        # enclosing `num_experts` instead. Equivalent for every call site in
        # this file, but worth confirming intent.
        for _ in range(num_repeats):
            with nvtx.annotate("cutlass_moe_fp4", color="green"):
                cutlass_moe_fp4(
                    a=a,
                    a1_gscale=a1_gs,
                    a2_gscale=a2_gs,
                    w1_fp4=w1_fp4,
                    w1_blockscale=w1_blockscale,
                    w1_alphas=w1_gs,
                    w2_fp4=w2_fp4,
                    w2_blockscale=w2_blockscale,
                    w2_alphas=w2_gs,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    m=m,
                    n=n,
                    k=k,
                    e=num_experts,
                    device=device,
                )

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a1_gscale: torch.Tensor,
        w1_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w1_alphas: torch.Tensor,
        a2_gscale: torch.Tensor,
        w2_fp4: torch.Tensor,
        w2_blockscale: torch.Tensor,
        w2_alphas: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
    ):
        # NOTE(review): a1_gscale/a2_gscale and `e` parameters are shadowed —
        # the call uses the enclosing a1_gs/a2_gs/num_experts. Callers here
        # pass the same objects, so behavior matches; confirm intent.
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return cutlass_moe_fp4(
                a=a,
                a1_gscale=a1_gs,
                w1_fp4=w1_fp4,
                w1_blockscale=w1_blockscale,
                w1_alphas=w1_alphas,
                a2_gscale=a2_gs,
                w2_fp4=w2_fp4,
                w2_blockscale=w2_blockscale,
                w2_alphas=w2_alphas,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                m=m,
                n=n,
                k=k,
                e=num_experts,
                device=device,
            )

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_fp8_scale: torch.Tensor,
    ):
        # Single invocation used for CUDA-graph capture.
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_fp8_scale,
            )

    def replay_graph(graph, num_repeats):
        # Replay a captured CUDA graph num_repeats times, then block.
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    # Capture the cutlass path into a CUDA graph on a dedicated stream.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a=a,
            a1_gscale=a1_gs,
            w1_fp4=w1_fp4,
            w1_blockscale=w1_blockscale,
            w1_alphas=w1_gs,
            a2_gscale=a2_gs,
            w2_fp4=w2_fp4,
            w2_blockscale=w2_blockscale,
            w2_alphas=w2_gs,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            m=m,
            n=n,
            k=k,
            e=num_experts,
            device=device,
        )
    torch.cuda.synchronize()

    # Capture the triton path (untransposed fp8 weights) the same way.
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_fp8q_notransp,
            w2_fp8q_notransp,
            topk_weights,
            topk_ids,
            w1_fp8scale,
            w2_fp8scale,
            a_fp8_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    # Namespace handed to torch.utils.benchmark.Timer; the stmt strings below
    # resolve their names against this dict.
    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_fp8q_notransp": w1_fp8q_notransp,
        "w2_fp8q_notransp": w2_fp8q_notransp,
        "w1_fp8scale": w1_fp8scale,
        "w2_fp8scale": w2_fp8scale,
        "a_fp8_scale": a_fp8_scale,
        # Cutlass params
        "a": a,
        "a1_gscale": a1_gs,
        "w1_fp4": w1_fp4,
        "w1_blockscale": w1_blockscale,
        "w1_alphas": w1_gs,
        "a2_gscale": a2_gs,
        "w2_fp4": w2_fp4,
        "w2_blockscale": w2_blockscale,
        "w2_alphas": w2_gs,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "m": m,
        "n": n,
        "k": k,
        "e": num_experts,
        "device": device,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe_fp4": run_cutlass_moe_fp4,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_fp8q_notransp,
        w2_fp8q_notransp,
        topk_weights,
        topk_ids,
        w1_fp8scale,
        w2_fp8scale,
        a_fp8_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe_fp4(
        a,
        w1_fp4,
        w2_fp4,
        w1_blockscale,
        w2_blockscale,
        w1_gs,
        w2_gs,
        a1_gs,
        a2_gs,
        topk_weights,
        topk_ids,
        m,
        n,
        k,
        num_experts,
        device,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="cutlass_moe_fp4",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="cutlass_moe_fp4_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )
def main(args):
    """Enumerate configurations from CLI args, run each, print a comparison."""
    print("Benchmarking models:")
    for idx, model_name in enumerate(args.models):
        print(f"[{idx}]  {model_name}")

    results: list[benchmark.Measurement] = []

    for model_name in args.models:
        for tp_size in args.tp_sizes:
            for layer_shape in WEIGHT_SHAPES_MOE[model_name]:
                experts = layer_shape[0]
                chosen_topk = layer_shape[1]
                size_k = layer_shape[2]
                # Hidden dim is sharded across the tensor-parallel ranks.
                size_n = layer_shape[3] // tp_size

                # Optional shape filters; an empty limit list means "no filter".
                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue
                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for act_token_opt in PER_ACT_TOKEN_OPTS:
                    for out_ch_opt in PER_OUT_CH_OPTS:
                        for batch in args.batch_sizes:
                            bench_run(
                                results,
                                model_name,
                                experts,
                                chosen_topk,
                                act_token_opt,
                                out_ch_opt,
                                (batch, size_k, size_n),
                            )

    compare = benchmark.Compare(results)
    compare.print()
# CLI entry point: every flag accepts one or more values (nargs="+").
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    # NOTE(review): the three limits below are accepted but not read by main()
    # in this file — confirm whether they are intentional placeholders.
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.utils.benchmark
as
benchmark
from
benchmark_shapes
import
WEIGHT_SHAPES_MOE
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_topk
,
)
from
vllm.utils
import
FlexibleArgumentParser
# Models whose MoE layer shapes are looked up in WEIGHT_SHAPES_MOE.
DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1",
    "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m",
    "ibm-granite/granite-3.0-3b-a800m",
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

# Only tensor-wise scaling variants are exercised by default.
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    """Benchmark one MoE configuration four ways and append the measurements.

    Times the triton fused_experts fp8 path and the cutlass_moe_fp8 grouped
    GEMM path, each both eagerly and replayed through a CUDA graph. All four
    measurements are appended to `results`.
    """
    label = "Quant Matmul"
    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )
    print(f"Testing: {sub_label}")
    (m, k, n) = mkn
    dtype = torch.half

    # Random activations/weights, scaled down to keep magnitudes small.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    _, a_scale = ops.scaled_fp8_quant(a)

    # Per-expert fp8 quantization of both weight matrices.
    w1_q = torch.empty(
        (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        a, score, topk, renormalize=False
    )

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
        num_repeats: int,
    ):
        # Eager triton fp8 path, repeated num_repeats times per timed call.
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_scale,
            )

    def run_cutlass_moe(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        per_act_token: bool,
        num_repeats: int,
    ):
        # Eager cutlass grouped-GEMM fp8 path. NOTE(review): a_scale is
        # accepted but the call passes a1_scale=None — confirm intent.
        for _ in range(num_repeats):
            cutlass_moe_fp8(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                w1_scale,
                w2_scale,
                per_act_token,
                a1_scale=None,
            )

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1_q: torch.Tensor,
        w2_q: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ):
        # Single invocation used for CUDA-graph capture; per_act_token comes
        # from the enclosing scope.
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return cutlass_moe_fp8(
                a,
                w1_q,
                w2_q,
                topk_weights,
                topk_ids,
                w1_scale,
                w2_scale,
                per_act_token,
                a1_scale=None,
            )

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
    ):
        # Single invocation used for CUDA-graph capture.
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_scale,
            )

    def replay_graph(graph, num_repeats):
        # Replay a captured CUDA graph num_repeats times, then block.
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    # Capture the cutlass path into a CUDA graph on a dedicated stream.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a,
            a_scale,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
        )
    torch.cuda.synchronize()

    # Capture the triton path the same way.
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_q,
            w2_q,
            topk_weights,
            topk_ids,
            w1_scale,
            w2_scale,
            a_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    # Namespace handed to torch.utils.benchmark.Timer; the stmt strings below
    # resolve their names against this dict.
    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "per_act_token": per_act_token,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        w1_scale,
        w2_scale,
        a_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe(
        a,
        a_scale,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        per_act_token,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )
def main(args):
    """Run the benchmark sweep for every model / TP size / layer shape.

    Iterates the MoE layer shapes registered for each requested model, applies
    the --limit-k/--limit-n filters, appends one measurement set per
    configuration via bench_run, and finally prints a comparison table.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                # Hidden dim is sharded across the tensor-parallel ranks.
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        # Bug fix: honor the --batch-sizes CLI flag. This loop
                        # previously iterated DEFAULT_BATCH_SIZES, silently
                        # ignoring user-supplied batch sizes (the NVFP4 sibling
                        # benchmark already uses args.batch_sizes).
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()
# CLI entry point: every flag accepts one or more values (nargs="+").
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    # NOTE(review): the three limits below are accepted but not read by main()
    # in this file — confirm whether they are intentional placeholders.
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_layernorm.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
torch
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
@torch.inference_mode()
def main(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
    seed: int = 0,
    do_profile: bool = False,
    num_warmup_iters: int = 5,
    num_iters: int = 100,
) -> None:
    """Benchmark vLLM's RMSNorm layer and print the mean latency per call.

    With do_profile=True a single iteration is run between
    cudaProfilerStart/Stop markers instead of the timed loop.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    # Scale inputs down so values stay small relative to hidden_size.
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        # Returns mean wall-clock seconds per iteration; synchronizes around
        # the timed region so GPU work is fully accounted for.
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            layer(x, residual)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")
# CLI entry point for the layernorm benchmark.
if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--add-residual", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations. "
        "If --profile is set, this number is ignored",
    )

    args = parser.parse_args()
    print(args)

    main(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        add_residual=args.add_residual,
        # Map the CLI dtype name ("half"/"bfloat16"/"float") to a torch dtype.
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
    )
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_lora.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
json
import
pickle
import
time
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
itertools
import
product
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
utils
import
ArgPool
,
Bench
,
CudaGraphBenchParams
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
,
lora_expand
,
lora_shrink
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.utils
import
FlexibleArgumentParser
# Default sweep values for the LoRA kernel benchmarks below.
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
    1,
    16,
    32,
    64,
    128,
    192,
    256,
    320,
    384,
    448,
    512,
    640,
    768,
    896,
    1024,
    2048,
    3072,
    4096,
    5120,
    6144,
    7168,
    8192,
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]
# Utilities
def dtype_to_str(dtype: torch.dtype):
    """Map a supported torch dtype to its short name ("f16"/"bf16"/"f32")."""
    short_names = {
        torch.float16: "f16",
        torch.bfloat16: "bf16",
        torch.float32: "f32",
    }
    name = short_names.get(dtype)
    if name is None:
        raise ValueError(f"Unsupported dtype {dtype}")
    return name
def make_rand_lora_weight_tensor(
    k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda"
) -> torch.Tensor:
    """Create a random (num_loras, n, k) LoRA weight tensor.

    LoRA weights are column major: the weight for each LoRA is stored (n, k).
    """
    weights = torch.rand((num_loras, n, k), dtype=dtype)
    return weights.to(device)
def make_rand_tensors(
    a_shape: tuple[int],
    b_shape: tuple[int],
    c_shape: tuple[int],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
    """Make LoRA input/output matrices.

    Returns (A, Bs, C): a random input A, `num_slices` random column-major
    LoRA weight tensors Bs, and a zero-initialized output C.
    """
    lhs = torch.rand(a_shape, dtype=a_dtype).to(device)
    # LoRA weights column major — one independent random tensor per slice.
    weight_slices = []
    for _ in range(num_slices):
        weight_slices.append(torch.rand(b_shape, dtype=b_dtype).to(device))
    out = torch.zeros(c_shape, dtype=c_dtype).to(device)
    return lhs, weight_slices, out
def make_prompt_lora_mapping(
    num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str
) -> torch.Tensor:
    """
    All prompts are mapped to a LoRA ID in range [0, num_active_loras).
    where 0 refers to first lora, 1 refers to second lora and so on.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        # Bug fix: place the random mapping on the requested device. This
        # branch previously ignored `device` (unlike the sorted branch below),
        # so the result always landed on the CPU.
        return torch.randint(
            0, num_active_loras, (num_prompts,), dtype=torch.long, device=device
        )

    # Divide LoRAs equally and in order.
    part_size = num_prompts // num_active_loras
    part_size = max(part_size, 1)

    lora_id = 0
    prompt_lora_mapping = []
    while len(prompt_lora_mapping) < num_prompts:
        prompt_lora_mapping.extend([lora_id] * part_size)
        # Advance to the next LoRA until the last one; then keep repeating it.
        lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id
    return torch.tensor(
        prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device
    )
def make_token_lora_mapping(
    num_tokens: int,
    num_prompts: int,
    prompt_lora_mapping: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    device: str,
):
    """
    Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor.

    Each prompt's LoRA index is repeated once per token of that prompt.
    """
    assert prompt_lora_mapping.shape[0] == num_prompts

    # token to lora index mapping, built prompt by prompt.
    mapping: list[int] = []
    for prompt_idx in range(num_prompts):
        lora_idx = prompt_lora_mapping[prompt_idx].item()
        mapping.extend([lora_idx] * seq_len_tensor[prompt_idx].item())
    # Pad with zeros when the sequence lengths cover fewer than num_tokens.
    if len(mapping) < num_tokens:
        mapping.extend([0] * (num_tokens - len(mapping)))

    return torch.tensor(mapping, dtype=torch.long, device=device)
def ref_group_gemm(
    ref_out: torch.Tensor,
    input: torch.Tensor,
    lora_weights: list[torch.Tensor],
    seq_lens_cpu: torch.Tensor,
    prompt_lora_mapping_cpu: torch.Tensor,
    scaling: float,
    add_inputs: Optional[bool],
):
    """
    Torch group gemm reference implementation to test correctness of
    benchmarking operations.

    Applies, per batch, the linear layer of the LoRA selected by
    prompt_lora_mapping_cpu, scales the result, and either accumulates
    into ref_out (add_inputs) or overwrites it.
    """
    chunks = []
    offset = 0
    for batch_idx, batch_len in enumerate(seq_lens_cpu):
        x = input[offset : offset + batch_len, :]
        offset += batch_len
        w = lora_weights[prompt_lora_mapping_cpu[batch_idx]]
        chunks.append(scaling * torch.nn.functional.linear(x, w))
    stacked = torch.cat(chunks, dim=0)

    if add_inputs:
        ref_out += stacked
    else:
        ref_out.copy_(stacked)
class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """

    LORA_SHRINK = auto()
    LORA_EXPAND = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        """Parse a case-insensitive op name into an OpType."""
        lowered = s.lower()
        if lowered == "lora_shrink":
            return OpType.LORA_SHRINK
        if lowered == "lora_expand":
            return OpType.LORA_EXPAND
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        return self is OpType.LORA_SHRINK

    def is_expand_fn(self) -> bool:
        return self is OpType.LORA_EXPAND

    def num_slices(self) -> list[int]:
        """Slice counts exercised when benchmarking this op."""
        return [1, 2, 3]

    def mkn(
        self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
    ) -> tuple[int, int, int]:
        """Return the (m, k, n) GEMM dimensions for this op."""
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            # shrink: (tokens x hidden) @ (hidden x rank)
            return num_tokens, hidden_size, lora_rank
        assert self.is_expand_fn()
        # expand: (tokens x rank) @ (rank x hidden)
        return num_tokens, lora_rank, hidden_size

    def matmul_dtypes(
        self, op_dtype: torch.dtype
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C
        """
        if self.is_shrink_fn():
            # shrink accumulates into an fp32 output.
            return op_dtype, op_dtype, torch.float32
        assert self.is_expand_fn()
        # expand consumes the fp32 shrink output.
        return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
        self,
        batch_size: int,
        seq_length: int,
        hidden_size: int,
        lora_rank: int,
        num_loras: int,
        num_slices: int,
    ) -> tuple[tuple[int], tuple[int], tuple[int]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major
        if self is OpType.LORA_SHRINK:
            # LoRA shrink kernels support num_slices inherently in the kernel.
            return ((m, k), b_shape, (num_slices, m, n))
        if self is OpType.LORA_EXPAND:
            # LoRA expand kernels support num_slices inherently in the kernel
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        """Return the kernel entry point to benchmark for this op."""
        if self is OpType.LORA_SHRINK:
            return lora_shrink
        if self is OpType.LORA_EXPAND:
            return lora_expand
        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(
        self,
        output: torch.Tensor,
        input: torch.Tensor,
        lora_weights: list[torch.Tensor],
        **kwargs,
    ) -> Callable:
        """Each benchmark operation expects the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self is OpType.LORA_SHRINK:
            # Each slice writes into its own leading-dim plane of the output.
            for s_idx in range(num_slices):
                ref_group_gemm(
                    ref_out=output[s_idx, :],
                    input=input,
                    lora_weights=lora_weights[s_idx],
                    **kwargs,
                )
        elif self is OpType.LORA_EXPAND:
            hidden_size = lora_weights[0].shape[1]
            # Each slice writes into its own column band of the output.
            for s_idx in range(num_slices):
                col0 = s_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, col0 : col0 + hidden_size],
                    input=input[s_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[s_idx],
                    **kwargs,
                )
        else:
            raise ValueError(f"Unrecognized optype {self}")
@dataclass
class BenchmarkContext:
    """
    LoRA benchmark context
    """

    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
    seq_length: Optional[int] = None
    num_slices: Optional[int] = None  # num_slices for slice based ops

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        """Return a shallow copy of self with seq_length overridden."""
        updated = copy.copy(self)
        updated.seq_length = seq_length
        return updated

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        """Return a shallow copy of self with num_slices overridden."""
        updated = copy.copy(self)
        updated.num_slices = num_slices
        return updated

    def bench_label(self) -> str:
        """Top-level label grouping all results of the same dtype."""
        return f"lora-{self.dtype}"

    def bench_sublabel(self, op_type: OpType) -> str:
        """JSON-encoded sub-label describing the problem size for op_type."""
        m, k, n = op_type.mkn(
            self.batch_size, self.seq_length, self.hidden_size, self.lora_rank
        )
        desc = {
            "bs": self.batch_size,
            "sl": self.seq_length,
            "m": m,
            "k": k,
            "n": n,
            "num_loras": self.num_loras,
            "sort_by_lora": self.sort_by_lora_id,
            "num_slices": self.num_slices,
        }
        return json.dumps(desc)
@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks
    """

    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: list[torch.Tensor]
    output: torch.Tensor
    # LoRA kernel metadata
    lora_kernel_meta: LoRAKernelMeta
    # Metadata tensors used in testing correctness
    seq_lens: torch.Tensor
    prompt_lora_mapping: torch.Tensor

    def io_types(self) -> str:
        # Short "<a>x<b>=><c>" dtype summary used in benchmark descriptions.
        return (
            f"{dtype_to_str(self.input.dtype)}x"
            f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
            f"{dtype_to_str(self.output.dtype)}"
        )

    @staticmethod
    def make(
        ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
    ) -> "BenchmarkTensors":
        # Build all input/output/metadata tensors for one benchmark run of
        # `op_type` sized by `ctx`.
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size,
            ctx.seq_length,
            ctx.hidden_size,
            ctx.lora_rank,
            ctx.num_loras,
            ctx.num_slices,
        )
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        input_tensor, lora_weights, output_tensor = make_rand_tensors(
            a_shape, b_shape, c_shape, a_type, b_type, c_type, num_slices=ctx.num_slices
        )

        # Make metadata tensors.
        # Keep the metadata tensors in the CPU for further processing if needed.
        # The tensors get moved to the GPU before benchmarking.
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Make metadata tensors involved in correctness testing.
        # Prepare seq lens tensor
        # All sequences are exactly ctx.seq_length long (randint over a
        # single-value range).
        seq_len_tensor = torch.randint(
            ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,)
        )
        assert total_tokens == seq_len_tensor.sum()
        # Prepare prompt lora indices tensor
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
        )

        # Make LoRAKernelMeta
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens,
            ctx.batch_size,
            prompt_lora_indices_tensor,
            seq_len_tensor,
            "cpu",
        )
        lora_kernel_meta = LoRAKernelMeta.make(
            max_loras=ctx.num_loras,
            max_num_tokens=token_lora_indices_tensor.size(0),
            device="cpu",
        )
        lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)

        return BenchmarkTensors(
            input_tensor,
            lora_weights,
            output_tensor,
            lora_kernel_meta,
            seq_len_tensor,
            prompt_lora_indices_tensor,
        )

    def sanity_check(self) -> None:
        """
        Fails asserts when non-conformality is detected.
        """
        num_tokens = self.input.shape[-2]
        # check metadata tensors
        assert torch.sum(self.seq_lens) == num_tokens
        num_seqs = self.seq_lens.shape[0]
        # assert self.seq_start_loc.shape[0] == num_seqs
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens

    def to_device(self, device: str):
        """
        Transfer tensors to device if the tensors aren't already on the device
        """

        def to_device(tensor: torch.Tensor):
            if tensor.device != device:
                tensor = tensor.to(device=device)
            return tensor

        self.input = to_device(self.input)
        self.output = to_device(self.output)
        self.seq_lens = to_device(self.seq_lens)
        self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
        for i in range(len(self.lora_weights_lst)):
            self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

        # LoRA meta: every dataclass field of LoRAKernelMeta is expected
        # to be a tensor; move each one.
        for field_name in LoRAKernelMeta.__dataclass_fields__:
            field = getattr(self.lora_kernel_meta, field_name)
            assert isinstance(field, torch.Tensor)
            setattr(self.lora_kernel_meta, field_name, to_device(field))

    def metadata(self) -> tuple[int, int, int, int]:
        """
        Return num_seqs, num_tokens, max_seq_len and num_slices
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def as_lora_shrink_kwargs(self) -> dict[str, Any]:
        # Build the kwargs dict consumed by the lora_shrink kernel.
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            "inputs": self.input,
            "lora_a_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "scaling": 1.0,
        }

    def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        # Build the kwargs dict consumed by the lora_expand kernel.
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            "inputs": self.input,
            "lora_b_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "offset_start": 0,
            "add_inputs": add_inputs,
        }

    def bench_fn_kwargs(
        self, op_type: OpType, add_inputs: Optional[bool] = None
    ) -> dict[str, Any]:
        # Dispatch to the kwargs builder matching op_type; add_inputs is
        # only meaningful (and required) for expand ops.
        if op_type.is_shrink_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.LORA_SHRINK:
            return self.as_lora_shrink_kwargs()
        if op_type == OpType.LORA_EXPAND:
            return self.as_lora_expand_kwargs(add_inputs)
        raise ValueError(f"Unrecognized optype {self}")

    def test_correctness(
        self, op_type: OpType, expand_fn_add_inputs: Optional[bool]
    ) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.
        """
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")

        # Snapshot the output as the reference accumulator, then zero the
        # real output before running the kernel.
        ref_output = self.output.clone()
        self.output.zero_()

        op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs,
        )

        # Looser tolerances for half-precision outputs.
        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)
def bench_optype(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: Optional[int] = None,
    expand_fn_add_inputs: Optional[bool] = None,
    test_correctness: bool = False,
) -> TMeasurement:
    """Benchmark one LoRA op for the given context.

    Builds `arg_pool_size` independent tensor sets (to defeat hardware
    caching), optionally validates them against the reference group gemm,
    warms up the kernel's internal pointer caches, and times the op —
    optionally inside a CUDA graph of `cuda_graph_nops` invocations.
    """
    assert arg_pool_size >= 1
    if op_type.is_shrink_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    # BenchmarkContext -> BenchmarkTensors
    bench_tensors: list[BenchmarkTensors] = [
        BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
    ]
    for bt in bench_tensors:
        bt.sanity_check()

    # Test correctness of our implementation.
    if test_correctness:
        assert all(
            [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors]
        )

    # BenchmarkTensors -> dict (kwargs)
    kwargs_list = [
        bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        for bt in bench_tensors
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup
    for kwargs in kwargs_list:
        op_type.bench_fn()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    describe_args = (
        f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
    )
    description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"

    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    timer = None
    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        op_type.bench_fn(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer
def bench_torch_mm(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: Optional[int] = None,
) -> TMeasurement:
    """
    Benchmark basic torch.mm as a roofline.

    When all the input tokens have the same LoRA ID, the LoRA kernels are just
    a matmul. This torch.mm benchmark serves as a roofline for that case.

    input op_type is used in determining the m, k, n dimensions for the matmul.
    """
    batch_size, hidden_size, lora_rank, seq_length, dtype = (
        ctx.batch_size,
        ctx.hidden_size,
        ctx.lora_rank,
        ctx.seq_length,
        ctx.dtype,
    )

    m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank)
    # For a fairer comparison.
    n = n * ctx.num_slices

    # Get matmul input and output tensors for A x B = C
    As, Bs, Cs = [], [], []
    for _ in range(arg_pool_size):
        As.append(torch.rand((m, k), dtype=dtype).to("cuda"))
        # B is created (n, k) and transposed so the kernel sees a (k, n) view.
        Bs.append(torch.rand((n, k), dtype=dtype).to("cuda").t())
        Cs.append(torch.rand((m, n), dtype=dtype).to("cuda"))

    # Make torch.mm kwargs
    mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)}

    description = (
        f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
        f"x{dtype_to_str(dtype)}"
        f"=>{dtype_to_str(dtype)})"
    )
    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        torch.mm,
        **mm_kwargs,
    ) as bench:
        return bench.run()
# runner
def use_cuda_graph_recommendation() -> str:
    # Advisory banner printed when CUDA graphs are disabled.
    # NOTE(review): the leading whitespace inside this literal is cosmetic
    # only — the text is printed verbatim as a recommendation message.
    return """
            Triton kernels have a significant launch overhead with
            launched directly via python. This overhead is more noticeable
            for small the problem sizes. For these cases, it is recommended
            to use the script with `--cuda-graph-nops N` to benchmark N
            consecutive invocations of the benchmarking operations from
            inside a CUDA Graph. Note that the returned measurement is for N
            invocations of the operation.
    """
def
print_timers
(
timers
:
list
[
TMeasurement
],
args
:
Optional
[
argparse
.
Namespace
]
=
None
):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
if
args
and
args
.
cuda_graph_nops
:
print
(
f
"Note : The timings reported above is for
{
args
.
cuda_graph_nops
}
"
"consecutive invocations of the benchmarking functions. "
f
"Please divide by
{
args
.
cuda_graph_nops
}
for single invocation "
"timings."
)
print
(
"Note on Comparison with torch.mm : The torch.mm numbers are "
"benchmark numbers of a simple matmul emulating the single lora "
"case. It is provided as a roofline for comparing our LoRA Kernel "
"implementations. It is expected that the LoRA kernels will be "
"slower than torch.mm in cases where num_loras is big. But for "
"small num_loras the goal should be to match the torch.mm numbers."
)
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    """Run the full benchmark sweep and report/persist the results.

    For every (context, seq_len, op, num_slices) combination, benchmarks a
    torch.mm roofline plus the LoRA op itself, prints per-seq-len tables,
    and optionally pickles all measurements to args.output_directory.
    """
    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: list[OpType] = args.op_types

            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices
                    )
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(
                            _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops
                        )
                    )

                    # Benchmark bench_op
                    # add_inputs only applies to expand ops; shrink gets [None].
                    expand_fn_add_inputs = (
                        [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    )
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(
                                _ctx,
                                args.arg_pool_size,
                                bench_op,
                                args.cuda_graph_nops,
                                add_input_arg,
                                args.test_correctness,
                            )
                        )

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        if not od.exists():
            od.mkdir()

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)
def as_benchmark_contexts(
    hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]:
    """Expand CLI arguments into the cross-product of benchmark contexts."""
    combos = product(  # noqa
        args.batch_sizes,
        list(hidden_sizes),
        lora_ranks,
        args.num_loras,
        args.sort_by_lora_id,
    )
    ctxs: list[BenchmarkContext] = []
    for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in combos:
        ctxs.append(
            BenchmarkContext(
                batch_size=batch_size,
                hidden_size=hidden_size,
                lora_rank=lora_rank,
                num_loras=num_loras,
                num_active_loras=args.num_active_loras
                if args.num_active_loras
                else num_loras,
                # To be filled based on the OpType to benchmark
                seq_length=None,
                sort_by_lora_id=sort_by_lora_id,
                dtype=args.dtype,
                # To be filled based on the OpType to benchmark
                num_slices=None,
            )
        )
    return ctxs
def run_list_bench(args: argparse.Namespace):
    """Entry point for the `list_bench` sub-command: benchmark explicit
    lists of hidden sizes and LoRA ranks."""
    print(args)

    print(
        "List bench :\n"
        f" Hidden Sizes {args.hidden_sizes}"
        f" LoRA Ranks {args.lora_ranks}"
    )

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)
def run_range_bench(args: argparse.Namespace):
    """Entry point for the `range_bench` sub-command: benchmark ranges of
    hidden sizes and LoRA ranks built from start/end/increment arguments."""
    print(args)

    # Both ranges are inclusive of their end value (hence the +1).
    hidden_sizes = list(
        range(
            args.hidden_sizes_start,
            args.hidden_sizes_end + 1,
            args.hidden_sizes_increment,
        )
    )
    lora_ranks = list(
        range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment)
    )

    print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args
    )

    run(args, bench_contexts)
def run_model_bench(args: argparse.Namespace):
    """Entry point for the `model_bench` sub-command: derive hidden sizes
    from known model weight shapes (adjusted for tensor parallelism)."""
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        # Collect the distinct N dimensions of the model's weight shapes,
        # after splitting the TP-sharded dimension across tp_size ranks.
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            hidden_sizes.add(KN[1])
        return hidden_sizes

    # Get all hidden sizes
    hidden_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size))

    print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)
if __name__ == "__main__":

    def to_torch_dtype(dt):
        """argparse `type=` converter for --dtype."""
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        """argparse `type=` converter for boolean flags ("true"/"1")."""
        return s.lower() in ["true", "1"]

    def add_common_command_args(p: argparse.ArgumentParser):
        """Register the arguments shared by all sub-commands."""
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']",
        )

        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            # Bug fix: the adjacent string literals previously concatenated
            # without separating spaces ("insteadof", "arg-poolmitigates").
            help="Run profiles with a pool of input/output/meta tensors instead "
            "of simply reusing the same tensors for all runs. A bigger arg-pool "
            "mitigates hardware caching effects during benchmarking.",
        )

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=(
                "when set profiling is done using cudagraph, "
                "with the given number of operations in a graph. "
                "Note that the measurement returned is the time "
                "taken for N consecutive executions of the benchmarking "
                "functions, where N is the value of this argument."
            ),
        )
        p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS)
        p.add_argument(
            "--num-active-loras",
            type=int,
            default=None,
            help="Active LoRAs. When None, all LoRAs are active",
        )
        p.add_argument(
            "--sort-by-lora-id",
            nargs="+",
            type=get_bool,
            default=DEFAULT_SORT_BY_LORA_IDS,
        )
        p.add_argument(
            "--op-types", nargs="+", type=OpType.from_str, default=list(OpType)
        )
        p.add_argument(
            "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS
        )
        p.add_argument(
            "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
        )
        p.add_argument(
            "--expand-fn-add-inputs",
            nargs="+",
            type=get_bool,
            default=DEFAULT_EXPAND_FN_ADD_INPUTS,
        )
        p.add_argument(
            "-o",
            "--output-directory",
            type=str,
            # Bug fix: "store a the list" typo and missing space before
            # "TMeasurement".
            help=(
                "Output directory to store the list of benchmarking "
                "TMeasurement objects as a pickle file"
            ),
        )

        p.add_argument(
            "--test-correctness",
            action="store_true",
            # Bug fix: missing space between "tested" and "for".
            help=(
                "When enabled, the benchmarking functions are tested "
                "for correctness before the actual benchmarking"
            ),
        )

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
    {use_cuda_graph_recommendation()}

    list_bench example:
        python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    model_bench example:
        python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    range_bench example:
        python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument(
        "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES
    )
    list_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment", type=int, required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment", type=int, required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_machete.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
math
import
os
import
pickle
as
pkl
import
time
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
itertools
import
product
from
typing
import
Callable
,
Optional
import
pandas
as
pd
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
marlin_permute_scales
,
marlin_zero_points
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_rows
,
quantize_weights
,
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
# Default sweep values for the machete benchmark CLI.
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]

# Enable NVTX range annotations via the NVTX_PROFILE environment variable.
# NOTE(review): any non-empty value (even "0") enables profiling, since the
# raw environment string is used for truthiness.
NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)

if NVTX_PROFILE:
    import nvtx
def
terse_type_name
(
dt
):
return
{
torch
.
bfloat16
:
"bf16"
,
torch
.
float16
:
"fp16"
,
torch
.
int8
:
"int8"
,
torch
.
float8_e4m3fn
:
"fp8"
,
torch
.
float
:
"float"
,
torch
.
int
:
"int"
,
}[
dt
]
@dataclass
class BenchmarkTensors:
    # Tensors for one quantized-GEMM benchmark case.
    # w_ref: dequantized reference weight (used for correctness baselines)
    w_ref: torch.Tensor
    # a: activation/input matrix
    a: torch.Tensor

    # w_q: packed quantized weight
    w_q: torch.Tensor
    # group_size: quantization group size (None for per-channel/tensor)
    group_size: Optional[int]
    # wtype: quantized weight scalar type
    wtype: ScalarType
    # w_g_s: per-group scales
    w_g_s: torch.Tensor
    # w_g_zp: per-group zero points (None when symmetric)
    w_g_zp: Optional[torch.Tensor]
    # w_ch_s: per-channel scales, if used
    w_ch_s: Optional[torch.Tensor]
    # w_tok_s: per-token scales, if used
    w_tok_s: Optional[torch.Tensor]
@dataclass
class TypeConfig:
    """Dtype selection for one benchmark run.

    A ``None`` for any optional field disables that kind of scale/zero-point.
    """

    # Activation dtype (also used to generate the raw weights).
    act_type: torch.dtype
    # Quantized weight element type.
    weight_type: ScalarType
    # GEMM output dtype; None lets the kernel choose its default.
    output_type: Optional[torch.dtype]
    # Dtype of the per-group weight scales.
    group_scale_type: Optional[torch.dtype]
    # Dtype of the per-group weight zero-points.
    group_zero_type: Optional[torch.dtype]
    # Dtype of the per-output-channel weight scales.
    channel_scale_type: Optional[torch.dtype]
    # Dtype of the per-token activation scales.
    token_scale_type: Optional[torch.dtype]
def rand_data(shape, dtype=torch.float16, scale=1):
    """Create a random CUDA tensor of the given shape and dtype.

    Floating dtypes draw from a shifted uniform distribution
    (``scale * U[0, 1) - 0.3``); integer dtypes draw uniformly from
    ``[-15, 15)``.
    """
    if not dtype.is_floating_point:
        return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
    uniform = torch.rand(shape, device="cuda")
    return (scale * uniform - 0.3).to(dtype)
def quantize_and_pack(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: Optional[torch.dtype],
    group_size: Optional[int],
    zero_points: bool = False,
):
    """Quantize weights to ``wtype`` and pack the quantized values row-wise.

    Returns ``(w_ref, w_q_packed, w_s, w_zp)``: dequantized reference,
    packed quantized weight, group scales, and zero-points (or None).
    NOTE(review): ``atype`` and ``stype`` are currently unused here;
    presumably kept for signature symmetry with callers — confirm.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    ref_weight, quantized, scales, zps = quantize_weights(
        w,
        wtype,
        group_size=group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True,
    )
    packed = pack_rows(quantized, wtype.size_bits, *quantized.shape)
    return ref_weight, packed, scales, zps
def create_bench_tensors(
    shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
) -> list[BenchmarkTensors]:
    """Create one activation and many quantized weight sets for (m, n, k).

    Multiple weight copies are generated so that successive benchmark runs
    cannot hit the same weights in L2 cache.
    """
    m, n, k = shape

    # we want to make sure that weights don't fit into L2 cache between runs so
    # we construct enough weights to exceed L2 cache, which is 50mb on a H100
    # so we target total weight size > 2*50mb
    num_weights = math.ceil(
        2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
    )

    a = rand_data((m, k), types.act_type, scale=5)

    benchmark_tensors: list[BenchmarkTensors] = []
    for _ in range(num_weights):
        w = rand_data((k, n), types.act_type, scale=5)

        if types.group_scale_type is not None:
            w = w.to(types.group_scale_type)
        # 1-byte dtypes (int8/fp8) cannot be quantized directly; widen first.
        if w.dtype.itemsize == 1:
            w = w.to(torch.float16)

        w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
            a.dtype,
            w,
            types.weight_type,
            types.group_scale_type,
            group_size,
            types.group_zero_type is not None,
        )

        # For integer activations, clamp the reference weight into the
        # activation's representable range before widening to float32.
        if not a.dtype.is_floating_point:
            aiinfo = torch.iinfo(a.dtype)
            w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)

        w_ref = w_ref.to(torch.float32)

        w_ch_s = (
            None
            if types.channel_scale_type is None
            else rand_data((n,), types.channel_scale_type)
        )
        w_tok_s = (
            None
            if types.token_scale_type is None
            else rand_data((m,), types.token_scale_type)
        )

        benchmark_tensors.append(
            BenchmarkTensors(
                w_ref=w_ref,
                a=a,
                w_q=w_q_packed,
                wtype=types.weight_type,
                w_g_s=w_s,
                w_g_zp=w_zp,
                group_size=group_size,
                w_ch_s=w_ch_s,
                w_tok_s=w_tok_s,
            )
        )

    return benchmark_tensors
def
torch_matmul_f16_create_bench_fn
(
bt
:
BenchmarkTensors
)
->
Callable
:
a
=
bt
.
a
w
=
bt
.
w_ref
.
to
(
bt
.
a
.
dtype
)
# use float reference tensor
if
a
.
dtype
not
in
[
torch
.
float16
,
torch
.
bfloat16
]:
a
=
a
.
to
(
torch
.
float16
)
w
=
w
.
to
(
torch
.
float16
)
return
lambda
:
torch
.
matmul
(
a
,
w
)
def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    """Build a zero-argument callable running ops.cutlass_scaled_mm.

    Uses the per-token / per-channel scales when both are present,
    otherwise scalar 1.0 scales. The weight is laid out column-major as
    the kernel expects.
    """
    have_both_scales = bt.w_ch_s is not None and bt.w_tok_s is not None
    if have_both_scales:
        scale_a = bt.w_tok_s.to(torch.float32)
        scale_b = bt.w_ch_s.to(torch.float32)
    else:
        scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
        scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)

    weight_cm = bt.w_ref.to(bt.a.dtype).t().contiguous().t()

    def run_scaled_mm():
        return ops.cutlass_scaled_mm(
            bt.a, weight_cm, scale_a, scale_b, out_dtype=torch.float16
        )

    return run_scaled_mm
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    """Build a zero-argument callable running a Marlin GEMM for ``bt``.

    Floating-point activations use ops.gptq_marlin_gemm; int8 activations
    (uint4b8 weights only) use ops.marlin_qqq_gemm.
    """
    device = bt.a.device

    workspace = MarlinWorkspace(
        bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    # Convert zero-points / scales into Marlin's expected layout; empty
    # tensors signal "not used" to the kernel.
    if bt.w_g_zp is None:
        w_zp = torch.empty(0, dtype=torch.int, device=device)
    else:
        w_zp = marlin_zero_points(
            bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
        )

    if bt.group_size is None:
        w_s = torch.tensor([], device="cuda", dtype=torch.half)
    else:
        w_s = marlin_permute_scales(
            bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
        )

    # No activation reordering (act_order) in this benchmark.
    sort_indices = torch.empty(0, dtype=torch.int, device=device)
    g_idx = torch.empty(0, dtype=torch.int, device=device)
    w_q = ops.gptq_marlin_repack(
        bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
    )

    if bt.a.dtype.is_floating_point:
        assert bt.w_ch_s is None
        assert bt.w_tok_s is None
        assert bt.group_size is not None

        fn = lambda: ops.gptq_marlin_gemm(
            a=bt.a,
            c=None,
            b_q_weight=w_q,
            b_scales=w_s,
            global_scale=None,
            b_zeros=w_zp,
            g_idx=g_idx,
            perm=sort_indices,
            workspace=workspace.scratch,
            b_q_type=bt.wtype,
            size_m=bt.a.shape[0],
            size_n=bt.w_ref.shape[1],
            size_k=bt.w_ref.shape[0],
            is_k_full=True,
            is_zp_float=False,
        )
    else:
        assert bt.a.dtype == torch.int8
        assert bt.wtype == scalar_types.uint4b8

        # marlin_qqq_gemm requires channel/token scales; substitute ones
        # when the benchmark config did not provide them.
        if bt.w_ch_s is not None:
            s_ch = bt.w_ch_s.to(torch.float32)
        else:
            s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
        if bt.w_tok_s is not None:
            s_tok = bt.w_tok_s.to(torch.float32)
        else:
            s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)

        fn = lambda: ops.marlin_qqq_gemm(
            a=bt.a,
            b_q_weight=w_q,
            s_group=w_s,
            s_tok=s_tok,
            s_ch=s_ch,
            workspace=workspace.scratch,
            size_m=bt.a.shape[0],
            size_n=bt.w_ref.shape[1],
            size_k=bt.w_ref.shape[0],
        )

    return fn
def machete_create_bench_fn(
    bt: BenchmarkTensors, out_type: Optional[torch.dtype] = None, schedule=None
) -> Callable:
    """Build a zero-argument callable running a machete GEMM for ``bt``.

    Args:
        bt: the benchmark tensors to run.
        out_type: output dtype forwarded to ops.machete_mm; ``None`` uses
            the kernel's default. (Fix: the previous default was the class
            object ``torch.dtype`` itself — ``out_type=torch.dtype`` — not
            an actual dtype; every call site passes ``out_type`` explicitly,
            so changing the default to ``None`` is backward-compatible.)
        schedule: optional machete schedule name to pin.
    """
    w_q = bt.w_q.t().contiguous().t()  # make col major
    w_q = ops.machete_prepack_B(
        w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
    )

    w_g_zp = bt.w_g_zp
    if w_g_zp is not None:
        # Pre-scale the zero-points (-scale * zp), matching how the kernel
        # applies them.
        w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))

    return lambda: ops.machete_mm(
        a=bt.a,
        b_q=w_q,
        b_type=bt.wtype,
        b_group_scales=bt.w_g_s,
        b_group_zeros=w_g_zp,
        b_group_size=bt.group_size,
        b_channel_scales=bt.w_ch_s,
        a_token_scales=bt.w_tok_s,
        out_type=out_type,
        schedule=schedule,
    )
# impl

# bench
def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
    """Time one pass over all ``fns`` with torch.utils.benchmark.

    Returns a single Measurement. Under NVTX profiling the autorange
    window is shortened and one extra annotated call is issued so the
    kernel shows up in the profile.
    """
    min_run_time = 1 if not NVTX_PROFILE else 0.1
    res = TBenchmark.Timer(
        stmt="""
        for fn in fns:
            fn()
        """,
        globals={"fns": fns},
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)

    if NVTX_PROFILE:
        with (
            nvtx.annotate("mm-bench"),
            nvtx.annotate(f"{label}|{sub_label}|{description}"),
        ):
            fns[0]()

    return res
# Accumulator for machete schedule-sweep rows (filled lazily by bench() when
# sweeping is enabled) and the CSV path it is written to by the entry point.
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
def bench(
    types: TypeConfig,
    group_size: int,
    m: int,
    k: int,
    n: int,
    label: str,
    sub_label: str,
    sweep_schedules: bool = True,
) -> list[TMeasurement]:
    """Benchmark all applicable kernels for one (m, k, n) problem.

    Always times torch.matmul and machete; adds cutlass_scaled_mm for
    int8/fp8 activations and marlin for non-fp8 activations. When
    ``sweep_schedules`` is set, additionally sweeps machete schedules,
    records each row into the module-level results DataFrame, and appends
    the best measurement.
    """
    benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
    sub_label += f", L={len(benchmark_tensors)}"

    # Compact description of the weight/scale type combination, e.g.
    # "Wuint4b8-Afp16-GSfp16-G128".
    name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
    if types.group_scale_type is not None:
        name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
    if types.group_zero_type is not None:
        name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
    if group_size is not None:
        name_type_string += f"-G{group_size}"
    if types.channel_scale_type is not None:
        name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
    if types.token_scale_type is not None:
        name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"

    timers = []
    # pytorch impl
    timers.append(
        bench_fns(
            label,
            sub_label,
            "torch.matmul (fp16)",
            [torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
        )
    )

    if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
                [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
            )
        )

    if types.act_type != torch.float8_e4m3fn:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"marlin ({name_type_string})",
                [marlin_create_bench_fn(bt) for bt in benchmark_tensors],
            )
        )

    # machete
    timers.append(
        bench_fns(
            label,
            sub_label,
            f"machete ({name_type_string})",
            [
                machete_create_bench_fn(bt, out_type=types.output_type)
                for bt in benchmark_tensors
            ],
        )
    )

    if sweep_schedules:
        global _SWEEP_SCHEDULES_RESULTS

        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(
            a_type=types.act_type,
            b_type=types.weight_type,
            group_scales_type=types.group_scale_type,
            group_zeros_type=types.group_zero_type,
            token_scales_type=types.token_scale_type,
            channel_scales_type=types.channel_scale_type,
            out_type=types.output_type,
        )

        if schedules is None or len(schedules) == 0:
            raise ValueError("No schedules found to sweep")

        for schedule in reversed(schedules):
            # Extract the schedule's tile M from its name,
            # e.g. "128x16_..." -> 16.
            schedule_M = int(schedule.split("_")[0].split("x")[1])

            # Prune known bad schedules
            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
                continue

            res = bench_fns(
                label,
                sub_label,
                "machete_best",
                [
                    machete_create_bench_fn(
                        bt, out_type=types.output_type, schedule=schedule
                    )
                    for bt in benchmark_tensors
                ],
            )

            results_row = {
                "M": m,
                "K": k,
                "N": n,
                "group_size": group_size,
                "schedule": schedule,
                "median": res.median,
            }
            if _SWEEP_SCHEDULES_RESULTS is None:
                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
            _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

            print(f"{res.median:5.5}", schedule)

            if not best or res.median < best.median:
                best = res
                best_schedule = schedule
        print("Best schedule:", best_schedule)
        # NOTE(review): if the pruning filter rejects every schedule, `best`
        # is still None here and None is appended to `timers` — confirm
        # downstream consumers tolerate that.
        timers.append(best)
    return timers
# runner
def
print_timers
(
timers
:
list
[
TMeasurement
]):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
    """Benchmark every (M, K, N) problem described by ``MKNs``.

    Builds a TypeConfig from the parsed CLI args, runs ``bench`` per
    problem, prints each batch of timers, and returns all measurements.
    """
    # uint4b8 weights when no zero-points are requested, plain uint4 otherwise.
    if args.group_zero_type is None:
        weight_type = scalar_types.uint4b8
    else:
        weight_type = scalar_types.uint4

    types = TypeConfig(
        act_type=args.act_type,
        weight_type=weight_type,
        output_type=args.out_type,
        group_scale_type=args.group_scale_type,
        group_zero_type=args.group_zero_type,
        channel_scale_type=args.channel_scale_type,
        token_scale_type=args.token_scale_type,
    )

    results: list[TMeasurement] = []
    for m, k, n in MKNs:
        measurements = bench(
            types,
            args.group_size,
            m,
            k,
            n,
            f"{args.act_type}-gemm",
            f"MKN=({m}x{k}x{n})",
            sweep_schedules=args.sweep_schedules,
        )
        print_timers(measurements)
        results.extend(measurements)

    return results
# output makers
def make_output(
    data: list[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    """Print all measurements and pickle them to ``<desc>-<timestamp>.pkl``.

    NOTE: ``MKNs`` is accepted for signature symmetry with callers but is
    not used here.
    """
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    if timestamp is None:
        timestamp = int(time.time())
    out_path = f"{base_description}-{timestamp}.pkl"
    with open(out_path, "wb") as f:
        pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
    """Entry point for the ``square_bench`` subcommand: M = K = N sweep.

    Fixes two defects: the call was ``run(args.dtype, args.sweep_schedules,
    MKNs)`` although ``run`` takes ``(args, MKNs)``, and ``args.dtype`` does
    not exist (the parser defines ``--act-type``), so the old code raised
    AttributeError before ever benchmarking.
    """
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args, MKNs)

    make_output(data, MKNs, f"square_bench-{args.act_type}")
def run_range_bench(args):
    """Entry point for the ``range_bench`` subcommand.

    Sweeps the cartesian product of the M, K, N ranges given as
    comma-separated ``start``/``end``/``increment`` triples.

    Fixes two defects: the call was ``run(args.dtype, args.sweep_schedules,
    MKNs)`` although ``run`` takes ``(args, MKNs)``, and ``args.dtype`` does
    not exist (the parser defines ``--act-type``), so the old code raised
    AttributeError before ever benchmarking.
    """
    m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
    m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
    m_increment, k_increment, n_increment = (
        int(x) for x in args.dim_increment.split(",")
    )
    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    MKNs = list(product(Ms, Ks, Ns))
    data = run(args, MKNs)

    make_output(data, MKNs, f"range_bench-{args.act_type}")
def run_model_bench(args):
    """Entry point for the ``model_bench`` subcommand.

    Benchmarks each model/TP-size combination over its layer shapes and
    the requested batch sizes, prints per-model results, and pickles the
    args plus all measurements to ``model_bench-<type>-<time>.pkl``.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
        # Per-layer (K, N) pairs with the TP-sharded dimension divided down.
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args, MKNs)
        model_bench_data.append(data)

    type_string = f"{args.act_type}"

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {type_string} {model}-TP{tp_size} ====")
        print_timers(data)

    timestr = time.strftime("%Y%m%d-%H%M%S")

    all_results = []
    for d in model_bench_data:
        all_results.extend(d)

    # pickle all data
    with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
        # The bound subcommand function is not picklable; drop it.
        args_dict = vars(args)
        args_dict.pop("func")
        pkl.dump(
            {
                "args": args_dict,
                "results": all_results,
            },
            f,
        )
if __name__ == "__main__":

    def to_torch_dtype(dt):
        """Map a CLI dtype string to the corresponding torch dtype."""
        return {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "int8": torch.int8,
            "float8_e4m3fn": torch.float8_e4m3fn,
            "int": torch.int,
            "float": torch.float,
        }[dt]

    class ToTorchDtype(argparse.Action):
        """argparse action converting the raw string AFTER the choices check."""

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, to_torch_dtype(values))

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.

    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--act-type",
        action=ToTorchDtype,
        required=True,
        choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
    )
    parser.add_argument(
        "--group-scale-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    # Fix: previously used `type=to_torch_dtype`, which converts the value
    # to a torch dtype BEFORE argparse checks membership in the string
    # `choices`, so any use of the flag was rejected. Use the same action
    # as the sibling options, which converts after validation.
    parser.add_argument(
        "--group-zero-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--channel-scale-type",
        action=ToTorchDtype,
        choices=["float"],
    )
    parser.add_argument(
        "--token-scale-type",
        action=ToTorchDtype,
        choices=["float"],
    )
    parser.add_argument(
        "--out-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--group-size",
        type=int,
        help="Available options are ['None', '-1', '128'], default=128",
        default=128,
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    parser.add_argument(
        "--sweep-csv-out",
        help="CSV to store sweep results",
        default="sch_sweep_results.csv",
    )
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument(
        "--dim-start",
        type=str,
        required=True,
        help="Start value for M,K,N as comma separated list",
    )
    range_parser.add_argument(
        "--dim-end",
        type=str,
        required=True,
        help="End value (inclusive) for M,K,N as comma separated list",
    )
    range_parser.add_argument(
        "--dim-increment",
        type=str,
        required=True,
        help="Increment value for M,K,N as comma separated list",
    )
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    model_parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()

    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
    args.func(args)

    # Persist any schedule-sweep rows collected by bench().
    if _SWEEP_SCHEDULES_RESULTS is not None:
        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_marlin.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.utils.benchmark
as
benchmark
from
benchmark_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
,
)
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
ALLSPARK_SUPPORTED_QUANT_TYPES
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
MARLIN_SUPPORTED_GROUP_SIZES
,
query_marlin_supported_quant_types
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
FP4_MARLIN_SUPPORTED_GROUP_SIZES
,
rand_marlin_weight_fp4_like
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
marlin_quant_fp8_torch
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
awq_marlin_quantize
,
marlin_quantize
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
gptq_quantize_weights
,
quantize_weights
,
sort_weights
,
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
# Default sweep configuration for the Marlin benchmark.
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

# Activation-reordering and k-full variants swept by main().
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    act_order: bool,
    is_k_full: bool,
    quant_type: ScalarType,
    group_size: int,
    size_m: int,
    size_k: int,
    size_n: int,
):
    """Benchmark one (model, quantization, shape) configuration.

    Appends Measurements to ``results`` in place. Returns early (appending
    nothing) for unsupported combinations of act_order / group_size.
    """
    label = "Quant Matmul"

    sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
        model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
    )

    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    # act_order is incompatible with channelwise/full-k groups and
    # zero-points; group_size must also divide k.
    if act_order and (group_size == -1 or group_size == size_k or has_zp):
        return
    if size_k % group_size != 0:
        return

    marlin_24_supported = (
        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
        and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
    )
    repack_supported = (
        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
        and group_size in MARLIN_SUPPORTED_GROUP_SIZES
    )
    allspark_supported = (
        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
        and group_size == -1
        and not act_order
        and is_k_full
    )

    def gen_marlin_params():
        # Marlin quant
        # NOTE(review): the fp4/fp8 branches `return` with no value for
        # unsupported group_size/act_order combinations, which makes the
        # tuple unpack at the call site fail on None — confirm those
        # combinations are filtered out by main() before reaching here.
        marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
        if quant_type == scalar_types.float4_e2m1f:
            if group_size != 16 or act_order:
                return
            marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
                b.T, group_size
            )
        elif quant_type == scalar_types.float8_e4m3fn:
            if group_size not in [-1, 128] or act_order:
                return
            marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
        elif group_size == 16:
            return
        elif has_zp:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
                b, quant_type, group_size
            )
        else:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
                marlin_quantize(b, quant_type, group_size, act_order)
            )
        return (
            marlin_w_ref,
            marlin_q_w,
            marlin_s,
            marlin_s2,
            marlin_zp,
            marlin_g_idx,
            marlin_sort_indices,
        )

    def gen_marlin_24_params():
        # 2:4-sparse Marlin operands, or all-None when unsupported.
        marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
        if marlin_24_supported:
            (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
                marlin_24_quantize(b, quant_type, group_size)
            )
        return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)

    def gen_repack_params():
        # GPTQ-packed weight + sort indices for the repack benchmark.
        q_w_gptq = None
        repack_sort_indices = None
        if repack_supported:
            (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
                b, quant_type, group_size, act_order
            )
            q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

            # For act_order, sort the "weights" and "g_idx"
            # so that group ids are increasing
            repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
            if act_order:
                (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
        return q_w_gptq, repack_sort_indices

    def gen_allspark_params():
        # AllSpark W8A16 operands; only supported on Ampere-class GPUs
        # (SM 80-89), so `allspark_supported` may be narrowed here.
        qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
            CUBLAS_M_THRESHOLD
        ) = None
        nonlocal allspark_supported
        if allspark_supported:
            properties = torch.cuda.get_device_properties(b.device.index)
            sm_count = properties.multi_processor_count
            sm_version = properties.major * 10 + properties.minor

            supported_arch = sm_version >= 80 and sm_version < 90
            allspark_supported = allspark_supported and supported_arch
            if supported_arch:
                w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
                qw = qw.to(torch.uint8)

                qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
                    qw, s, zp, has_zp
                )
                CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
        return (
            qw_reorder,
            s_reorder,
            zp_reorder,
            sm_count,
            sm_version,
            CUBLAS_M_THRESHOLD,
        )

    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        marlin_g_idx,
        marlin_sort_indices,
    ) = gen_marlin_params()
    marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
        gen_marlin_24_params()
    )
    q_w_gptq, repack_sort_indices = gen_repack_params()
    qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
        gen_allspark_params()
    )

    # Prepare
    marlin_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    marlin_24_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    # Namespace handed to benchmark.Timer stmts below.
    # NOTE(review): this local intentionally shadows the `globals` builtin.
    globals = {
        # Gen params
        "quant_type": quant_type,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_s2": marlin_s2,
        "marlin_zp": marlin_zp,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # Marlin_24 params
        "marlin_24_w_ref": marlin_24_w_ref,
        "marlin_24_q_w_comp": marlin_24_q_w_comp,
        "marlin_24_meta": marlin_24_meta,
        "marlin_24_s": marlin_24_s,
        "marlin_24_workspace": marlin_24_workspace,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # AllSpark W8A16 params
        "qw_reorder": qw_reorder,
        "s_reorder": s_reorder,
        "zp_reorder": zp_reorder,
        "sm_count": sm_count,
        "sm_version": sm_version,
        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
    }

    min_run_time = 1

    # Warmup pytorch
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    results.append(
        benchmark.Timer(
            stmt="torch.matmul(a, marlin_w_ref)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="pytorch_gemm",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    results.append(
        benchmark.Timer(
            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    results.append(
        benchmark.Timer(
            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm_fp32",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    if marlin_24_supported:
        results.append(
            benchmark.Timer(
                stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="gptq_marlin_24_gemm",
            ).blocked_autorange(min_run_time=min_run_time)
        )

    if repack_supported:
        results.append(
            benchmark.Timer(
                stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="gptq_marlin_repack",
            ).blocked_autorange(min_run_time=min_run_time)
        )

    if allspark_supported:
        results.append(
            benchmark.Timer(
                stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="allspark_w8a16_gemm_fp32",
            ).blocked_autorange(min_run_time=min_run_time)
        )
def main(args):
    """Sweep all model shapes / quantization options and print a comparison.

    Each ``--limit-*`` argument, when non-empty, acts as a whitelist for
    the corresponding sweep dimension.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: list[benchmark.Measurement] = []
    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k = layer[0]
            size_n = layer[1]

            if len(args.limit_k) > 0 and size_k not in args.limit_k:
                continue

            if len(args.limit_n) > 0 and size_n not in args.limit_n:
                continue

            for act_order in ACT_ORDER_OPTS:
                if (
                    len(args.limit_act_order) > 0
                    and act_order not in args.limit_act_order
                ):
                    continue

                for is_k_full in K_FULL_OPTS:
                    if (
                        len(args.limit_k_full) > 0
                        and is_k_full not in args.limit_k_full
                    ):
                        continue

                    for quant_type in query_marlin_supported_quant_types():
                        if (
                            len(args.limit_num_bits) > 0
                            and quant_type.size_bits not in args.limit_num_bits
                        ):
                            continue

                        for group_size in (
                            MARLIN_SUPPORTED_GROUP_SIZES
                            + FP4_MARLIN_SUPPORTED_GROUP_SIZES
                        ):
                            if (
                                len(args.limit_group_size) > 0
                                and group_size not in args.limit_group_size
                            ):
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and (
                                group_size == size_k or group_size == -1
                            ):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(
                                    results,
                                    model,
                                    act_order,
                                    is_k_full,
                                    quant_type,
                                    group_size,
                                    size_m,
                                    size_k,
                                    size_n,
                                )

    compare = benchmark.Compare(results)
    compare.print()
# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    # Empty defaults mean "no filter" for every --limit-* whitelist below.
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
    parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
    parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_moe.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
json
import
time
from
contextlib
import
nullcontext
from
datetime
import
datetime
from
itertools
import
product
from
typing
import
Any
,
TypedDict
import
ray
import
torch
from
ray.experimental.tqdm_ray
import
tqdm
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
class BenchmarkConfig(TypedDict):
    """Triton fused-MoE kernel tuning parameters for one benchmark run."""

    # Tile sizes of the Triton GEMM along M / N / K.
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    # Thread-block grouping factor along M.
    GROUP_SIZE_M: int
    # Triton launch parameters.
    num_warps: int
    num_stages: int
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: list[int] = None,
    use_deep_gemm: bool = False,
) -> float:
    """Time one fused-MoE kernel config and return the average latency in us.

    Builds random activations/weights/scales for the requested quantization
    mode, captures 10 kernel invocations in a CUDA graph, then replays the
    graph ``num_iters`` times with fresh gating inputs. The returned value is
    the per-invocation average (total / (num_iters * 10), converted to us).
    Requires a CUDA device.
    """
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        # int8 weights: uniform random int8 values in [-127, 127).
        w1 = torch.randint(
            -127,
            127,
            (
                num_experts,
                shard_intermediate_size,
                hidden_size,
            ),
            dtype=torch.int8,
        )
        w2 = torch.randint(
            -127,
            127,
            (
                num_experts,
                hidden_size,
                shard_intermediate_size // 2,
            ),
            dtype=torch.int8,
        )
    else:
        # fp8 weights are first created in fp16 (init_dtype) and cast below.
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
    # One gating tensor per timed iteration; copied into input_gating by
    # prepare() so each graph replay sees different routing.
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        if block_quant_shape:
            # Block-quantized fp8: one scale per (block_n x block_k) weight tile.
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
                * factor_for_scale
            )
        else:
            # Per-tensor fp8 scales.
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)
            a1_scale = torch.randn(1, dtype=torch.float32)
            a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        # Refresh routing inputs in place so the captured graph reads new data.
        input_gating.copy_(gating_output[i])

    def run():
        from vllm.model_executor.layers.fused_moe import override_config

        with override_config(config):
            if use_deep_gemm:
                topk_weights, topk_ids, token_expert_indices = fused_topk(
                    x, input_gating, topk, False
                )
                return fused_experts(
                    x,
                    w1,
                    w2,
                    topk_weights,
                    topk_ids,
                    inplace=True,
                    use_fp8_w8a8=use_fp8_w8a8,
                    w1_scale=w1_scale,
                    w2_scale=w2_scale,
                    a1_scale=a1_scale,
                    a2_scale=a2_scale,
                    block_shape=block_quant_shape,
                    allow_deep_gemm=True,
                )
            else:
                fused_moe(
                    x,
                    w1,
                    w2,
                    input_gating,
                    topk,
                    renormalize=True,
                    inplace=True,
                    use_fp8_w8a8=use_fp8_w8a8,
                    use_int8_w8a16=use_int8_w8a16,
                    w1_scale=w1_scale,
                    w2_scale=w2_scale,
                    a1_scale=a1_scale,
                    a2_scale=a2_scale,
                    block_shape=block_quant_shape,
                )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() is ms per replay (10 kernel calls); average per call in us.
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
def get_rocm_tuning_space(use_fp16):
    """Return the ROCm tuning grid: config key -> list of candidate values.

    fp16 grids include the MFMA-specific keys (``matrix_instr_nonkdim``,
    ``kpack``); fp8 grids omit them and drop BLOCK_K=16.
    """
    mn_candidates = [16, 32, 64, 128, 256]
    if use_fp16:
        k_candidates = [16, 32, 64, 128, 256]
    else:
        # BLOCK_K=16 not supported for fp8
        k_candidates = [32, 64, 128, 256]

    space = {
        "BLOCK_SIZE_M": mn_candidates,
        "BLOCK_SIZE_N": mn_candidates,
        "BLOCK_SIZE_K": k_candidates,
        "GROUP_SIZE_M": [1, 4, 8, 16, 32],
        "num_warps": [1, 2, 4, 8],
        "num_stages": [2],
        "waves_per_eu": [0],
    }
    if use_fp16:
        space["matrix_instr_nonkdim"] = [16, 32]
        space["kpack"] = [1, 2]
    return space
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
    """Enumerate the full compute-bound tuning grid as kernel config dicts.

    On ROCm the grid comes from get_rocm_tuning_space(); elsewhere a reduced
    CUDA grid is used. When fp8 block quantization is active, configs whose
    BLOCK_SIZE_K / BLOCK_SIZE_N are not multiples of the quant block are
    filtered out.
    """
    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [32, 64, 128, 256],
            "BLOCK_SIZE_K": [64, 128, 256],
            "GROUP_SIZE_M": [1, 16, 32, 64],
            "num_warps": [4, 8],
            "num_stages": [2, 3, 4, 5],
        }

    keys = list(param_ranges.keys())
    configs: list[BenchmarkConfig] = [
        dict(zip(keys, combo)) for combo in product(*param_ranges.values())
    ]

    # Remove configs that are not compatible with fp8 block quantization:
    # BLOCK_SIZE_K must be a multiple of block_k and
    # BLOCK_SIZE_N must be a multiple of block_n.
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        configs = [
            cfg
            for cfg in configs
            if cfg["BLOCK_SIZE_K"] % block_k == 0
            and cfg["BLOCK_SIZE_N"] % block_n == 0
        ]
    return configs
def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    """Prune the ROCm grid against both MoE GEMM shapes and merge survivors.

    GEMM 1 has N=shard_intermediate_size, K=hidden_size; GEMM 2 has
    N=hidden_size, K=shard_intermediate_size // 2. Both see M = tokens * topk.
    """
    effective_m = num_tokens * topk
    pruned_for_gemm1 = prune_rocm_configs(
        effective_m, shard_intermediate_size, hidden_size, search_space, is_fp16
    )
    pruned_for_gemm2 = prune_rocm_configs(
        effective_m, hidden_size, shard_intermediate_size // 2, search_space, is_fp16
    )
    return merge_unique_dicts(pruned_for_gemm1, pruned_for_gemm2)
# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Drop configs that cannot run well (or at all) for an MxNxK GEMM on ROCm.

    Each config is kept only if it survives every heuristic below; the checks
    are cheap shape/resource tests, applied in order via ``continue``.
    """
    pruned_configs = []
    # Element sizes used for the shared-memory (LDS) estimate: 2 bytes for
    # fp16 operands, 1 byte otherwise (fp8/int8).
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    # MFMA instruction size: small GEMMs use the 16-wide variant.
    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = False
    if M >= 2048 and N >= 2048:
        large_gemm = True

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            # matrix_instr_nonkdim is only present in fp16 grids
            # (see get_rocm_tuning_space); it must not exceed the MFMA width.
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        if is_fp16:
            # The MFMA tile must fit inside the block tile, and for tiny
            # M/N the block size must match the instruction size exactly.
            if (
                matrix_instr_nonkdim > BLOCK_SIZE_M
                or matrix_instr_nonkdim > BLOCK_SIZE_N
            ):
                continue
            if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
                continue
            if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
                continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)
    return pruned_configs
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Split-K is worthwhile only for a skinny output with a deep K dim."""
    skinny_output = SIZE_M < 64 or SIZE_N < 64
    deep_reduction = SIZE_K > 1024
    return skinny_output and deep_reduction
def merge_unique_dicts(list1, list2):
    """Concatenate two dict lists, keeping only the first occurrence of each
    (dicts are unhashable, so membership is checked by equality)."""
    deduped = []
    for entry in list1 + list2:
        if entry not in deduped:
            deduped.append(entry)
    return deduped
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor that owns one GPU and runs benchmark/tune jobs on it."""

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: list[int] = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        """Benchmark the shipped (or default) config for this shape.

        Returns the config used and its average kernel latency in us.
        """
        current_platform.seed_everything(self.seed)
        dtype_str = get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                is_marlin=False,
            )
        else:
            # Pick the tuned config whose batch size is closest to num_tokens.
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        """Exhaustively benchmark *search_space* and return the fastest config."""
        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )

        # On ROCm, Ray exposes the GPU via ROCR_VISIBLE_DEVICES; if this
        # worker's device is not the one visible, guard with an explicit
        # torch.cuda.device context.
        need_device_guard = False
        if current_platform.is_rocm():
            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            if visible_device != f"{self.device_id}":
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        # Fixed: original message was missing the opening bracket around the
        # timestamp ("...ctime()}] Completed...").
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return *config* with keys in canonical order, ROCm-only keys last."""
    ordered = {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
    }
    # Optional ROCm-specific keys, appended in a fixed order when present.
    for optional_key in ("waves_per_eu", "matrix_instr_nonkdim", "kpack"):
        if optional_key in config:
            ordered[optional_key] = config[optional_key]
    return ordered
def save_configs(
    configs: dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    block_quant_shape: list[int],
) -> None:
    """Write the tuned per-batch-size configs to the canonical JSON file."""
    precision_tag = get_config_dtype_str(
        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    out_path = get_config_file_name(
        num_experts, shard_intermediate_size // 2, precision_tag, block_quant_shape
    )

    print(f"Writing best config to {out_path}...")
    with open(out_path, "w") as fh:
        json.dump(configs, fh, indent=4)
        fh.write("\n")
def get_weight_block_size_safety(config, default_value=None):
    """Read quantization_config['weight_block_size'] from a model config,
    returning *default_value* when absent or not dict-shaped."""
    quant_cfg = getattr(config, "quantization_config", {})
    if not isinstance(quant_cfg, dict):
        # e.g. an HF quantization-config object rather than a plain dict
        return default_value
    return quant_cfg.get("weight_block_size", default_value)
def main(args: argparse.Namespace):
    """Tune or benchmark the fused-MoE kernel for the given model/TP setup.

    Derives the MoE shape (num experts, top-k, sharded intermediate size)
    from the model config, then fans the per-batch-size work out to one Ray
    worker per available GPU. With ``--tune`` the best config per batch size
    is searched and saved; otherwise the shipped config is benchmarked.
    """
    print(args)

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        config = getattr(config, args.model_prefix)

    # Per-architecture extraction of the MoE dimensions.
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = args.batch_size

    use_deep_gemm = bool(args.use_deep_gemm)

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        # Fixed: `logger` was never defined in this script (NameError on this
        # path); also added the missing space between the concatenated
        # message fragments.
        logger = init_logger(__name__)
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        val = os.environ["HIP_VISIBLE_DEVICES"]
        os.environ["ROCR_VISIBLE_DEVICES"] = val
        del os.environ["HIP_VISIBLE_DEVICES"]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Round-robin the per-batch-size jobs across the worker pool.
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            block_quant_shape,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")
if __name__ == "__main__":
    # CLI entry point: parse tuning/benchmark options and dispatch to main().
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
    )
    # Quantization mode for weights/activations; "auto" uses the model dtype.
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    # One or more batch sizes; omitted -> the default sweep in main().
    parser.add_argument("--batch-size", type=int, nargs="+", required=False)
    # --tune searches the config space instead of benchmarking shipped configs.
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--model-prefix", type=str, required=False)
    args = parser.parse_args()

    main(args)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_moe_align_block_size.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
itertools
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size_triton
,
)
from
vllm.triton_utils
import
triton
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    """Sample, for each token, `topk` distinct expert ids.

    Returns an int32 CUDA tensor of shape (num_tokens, topk); ids within a
    row are unique because each row is a slice of a random permutation.
    """
    rows = []
    for _ in range(num_tokens):
        perm = torch.randperm(num_experts, dtype=torch.int32, device="cuda")
        rows.append(perm[:topk])
    return torch.stack(rows)
def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
    """
    Verifies vllm vs. Triton

    Runs both moe_align_block_size implementations on the same random top-k
    ids and compares expert_ids and num_tokens_post_pad, printing the result.
    NOTE(review): sorted_ids are deliberately not compared here — presumably
    because intra-block ordering may differ between implementations; confirm.
    Requires a CUDA device.
    """
    topk_ids = get_topk_ids(num_tokens, num_experts, topk)

    # 1. malloc space for triton and vllm
    # malloc enough space (max_num_tokens_padded) for the sorted ids
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids_triton = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
    )
    sorted_ids_triton.fill_(topk_ids.numel())  # fill with sentinel value
    expert_ids_triton = torch.zeros(
        (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
    )
    num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")

    sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
    sorted_ids_vllm.fill_(topk_ids.numel())
    expert_ids_vllm = torch.zeros_like(expert_ids_triton)
    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)

    # 2. run implementations
    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )

    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_vllm,
        expert_ids_vllm,
        num_tokens_post_pad_vllm,
    )
    print(f"✅ VLLM implementation works with {num_experts} experts!")

    # 3. compare results
    if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
        num_tokens_post_pad_triton, num_tokens_post_pad_vllm
    ):
        print("✅ Triton and VLLM implementations match.")
    else:
        print("❌ Triton and VLLM implementations DO NOT match.")
        print("Triton expert_ids:", expert_ids_triton)
        print("VLLM expert_ids:", expert_ids_vllm)
        print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
# test configurations
# Cross-product of these ranges forms the x-axis values consumed by the
# @triton.testing.perf_report benchmark below.
num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512]
topk_range = [1, 2, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "triton"],  # "triton"
        line_names=["VLLM", "Triton"],  # "Triton"
        plot_name="moe-align-block-size-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Benchmark function for Triton.

    Times moe_align_block_size for the selected provider ("vllm" custom op
    vs. the Triton kernel) and returns (median, max, min) latencies in us.
    Output buffers are clone()d per call so each run starts from the same
    sentinel-filled state. Requires a CUDA device.
    """
    block_size = 256
    topk_ids = get_topk_ids(num_tokens, num_experts, topk)

    # Worst-case padded length: every expert padded up to a full block.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
    sorted_ids.fill_(topk_ids.numel())  # sentinel: out-of-range token id
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: ops.moe_align_block_size(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
            ),
            quantiles=quantiles,
        )

    # Convert ms -> us.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
    # CLI entry point: run a correctness check, then the perf sweep.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_experts",
        type=int,
        default=64,
        choices=[8, 16, 32, 64, 128, 256],
    )
    parser.add_argument(
        "--topk",
        type=int,
        default=8,
        choices=[2, 4, 8],
        help="Top-k value for correctness check.",
    )
    args = parser.parse_args()

    print("Running correctness check...")
    check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
    benchmark.run(print_data=True, show_plots=True)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_moe_permute_unpermute.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
typing
import
Any
,
TypedDict
import
ray
import
torch
from
transformers
import
AutoConfig
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_moe_permute
,
_moe_unpermute_and_reduce
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
*
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
# Platform-native FP8 dtype. NOTE(review): not referenced elsewhere in this
# script's visible code — possibly kept for parity with benchmark_moe.py.
FP8_DTYPE = current_platform.fp8_dtype()
class BenchmarkConfig(TypedDict):
    """Triton fused-MoE kernel launch parameters (shared schema with
    benchmark_moe.py; not consumed directly in this script's visible code)."""

    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int
def benchmark_permute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    """Time the MoE permute step (custom moe_permute vs. _moe_permute).

    Captures 10 invocations in a CUDA graph and replays it ``num_iters``
    times with fresh gating inputs; returns the per-invocation average in us.
    Requires a CUDA device.
    """
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    # Routing is computed once; only the gating buffer is refreshed per iter.
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
                moe_permute(
                    qhidden_states,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    token_expert_indices=token_expert_indices,
                    topk=topk,
                    n_expert=num_experts,
                    n_local_expert=num_experts,
                    expert_map=None,
                    align_block_size=align_block_size,
                )
            )
        else:
            (
                permuted_hidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = _moe_permute(
                qhidden_states, None, topk_ids, num_experts, None, align_block_size
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() is ms per replay (10 calls); average per call in us.
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
def benchmark_unpermute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    """Time the MoE unpermute/reduce step.

    The permuted input is built once by ``prepare()`` (cast to fp16/bf16 to
    mimic a GEMM output), then the unpermute call is captured 10x in a CUDA
    graph and replayed ``num_iters`` times. Returns the per-invocation
    average latency in us. Requires a CUDA device.
    """
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare():
        # Build the permuted tensors that the unpermute kernel consumes.
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
                moe_permute(
                    qhidden_states,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    token_expert_indices=token_expert_indices,
                    topk=topk,
                    n_expert=num_experts,
                    n_local_expert=num_experts,
                    expert_map=None,
                    align_block_size=align_block_size,
                )
            )
            # convert to fp16/bf16 as gemm output
            return (
                permuted_hidden_states.to(dtype),
                first_token_off,
                inv_perm_idx,
                m_indices,
            )
        else:
            (
                permuted_qhidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = _moe_permute(
                qhidden_states, None, topk_ids, num_experts, None, align_block_size
            )
            # convert to fp16/bf16 as gemm output
            return (
                permuted_qhidden_states.to(dtype),
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            )

    def run(input: tuple):
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
            moe_unpermute(
                permuted_hidden_states,
                topk_weights,
                topk_ids,
                inv_perm_idx,
                first_token_off,
                topk,
                num_experts,
                num_experts,
            )
        else:
            (
                permuted_hidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = input
            _moe_unpermute_and_reduce(
                output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
            )

    # JIT compilation & warmup
    input = prepare()
    run(input)
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run(input)
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    latencies: list[float] = []
    for i in range(num_iters):
        torch.cuda.synchronize()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() is ms per replay (10 calls); average per call in us.
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor that owns one GPU and times the permute/unpermute kernels."""

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_customized_permute: bool = False,
    ) -> tuple[float, float]:
        # Fixed: the original annotated the return as
        # tuple[dict[str, int], float], but the method returns two latency
        # floats (permute us, unpermute us).
        current_platform.seed_everything(self.seed)
        permute_time = benchmark_permute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        unpermute_time = benchmark_unpermute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        return permute_time, unpermute_time
def get_weight_block_size_safety(config, default_value=None):
    """Read quantization_config['weight_block_size'] from a model config,
    returning *default_value* when absent or not dict-shaped."""
    quant_cfg = getattr(config, "quantization_config", {})
    if not isinstance(quant_cfg, dict):
        # e.g. an HF quantization-config object rather than a plain dict
        return default_value
    return quant_cfg.get("weight_block_size", default_value)
def main(args: argparse.Namespace):
    """Benchmark MoE permute/unpermute across batch sizes for a model.

    Extracts (num experts, top-k) from the HF config, then round-robins one
    benchmark job per batch size across one Ray worker per available GPU and
    prints per-batch permute/unpermute latencies.
    """
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    # Per-architecture extraction of the MoE routing parameters.
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
    elif (
        config.architectures[0] == "DeepseekV3ForCausalLM"
        or config.architectures[0] == "DeepseekV2ForCausalLM"
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
        E = config.num_experts
        topk = config.num_experts_per_tok
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    # ROCm path benchmarks in fp16; otherwise use the model's own dtype.
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_customized_permute = args.use_customized_permute

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Round-robin the jobs across the worker pool.
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    outputs = _distribute(
        "benchmark",
        [
            (
                batch_size,
                E,
                hidden_size,
                topk,
                dtype,
                use_fp8_w8a8,
                use_int8_w8a16,
                use_customized_permute,
            )
            for batch_size in batch_sizes
        ],
    )

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")
if __name__ == "__main__":
    # CLI entry point: parse benchmark options and dispatch to main().
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    # Use the customized moe_permute/moe_unpermute kernels instead of the
    # _moe_permute/_moe_unpermute_and_reduce pair.
    parser.add_argument("--use-customized-permute", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    # Single batch size; omitted -> the default sweep in main().
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    main(args)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_paged_attention.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
time
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
,
)
# Module-level logger for this benchmark script.
logger = init_logger(__name__)

# Total number of KV-cache blocks allocated for the benchmark.
NUM_BLOCKS = 128 * 1024
# Paged-attention v2 partition size; overridden at runtime on ROCm in main().
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    """Benchmark the paged attention kernel (v1, v2, or the ROCm custom op).

    Builds random queries, block tables, and a KV cache, then times
    `num_iters` kernel launches and prints the mean latency.

    NOTE(review): this function reads the module-level `args` (parsed in
    `__main__`) for `args.custom_paged_attn` — calling main() without running
    the script as a program raises NameError on the v2 path.
    NOTE(review): it also mutates the global PARTITION_SIZE on ROCm.
    """
    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(
        num_seqs, num_query_heads, head_size, dtype=dtype, device=device
    )
    query.uniform_(-scale, scale)

    # Grouped-query attention requires an integral query/kv head ratio.
    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)

    # All sequences share the same length in this benchmark.
    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Create the block tables: each sequence gets random block indices into
    # the NUM_BLOCKS-sized cache pool.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)

    # Create the KV cache (single layer).
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel. v2 additionally needs
    # per-partition scratch buffers for the split-sequence reduction.
    output = torch.empty_like(query)
    if version == "v2":
        if current_platform.is_rocm():
            global PARTITION_SIZE
            if not args.custom_paged_attn and not current_platform.is_navi():
                PARTITION_SIZE = 1024
            else:
                PARTITION_SIZE = PARTITION_SIZE_ROCM
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Time `num_iters` kernel launches; return mean seconds/iteration."""
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        # Using default kv_scale
        k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

        for _ in range(num_iters):
            if version == "v1":
                ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            elif version == "v2":
                if not args.custom_paged_attn:
                    ops.paged_attention_v2(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
                else:
                    ops.paged_attention_rocm(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        None,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=3, profile=False)

    # Benchmark: a single profiled iteration, or 100 timed iterations.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == "__main__":
    logger.warning(
        "This script benchmarks the paged attention kernel. "
        "By default this is no longer used in vLLM inference."
    )

    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
    parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
    )
    parser.add_argument(
        "--custom-paged-attn", action="store_true", help="Use custom paged attention"
    )
    # NOTE: `args` must stay a module-level name — main() reads
    # args.custom_paged_attn via the global scope.
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_quant.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
@torch.inference_mode()
def main(
    num_tokens: int,
    hidden_size: int,
    static_scale: bool,
    quant_dtype: torch.dtype,
    dtype: torch.dtype,
    seed: int = 0,
    do_profile: bool = False,
    num_warmup_iters: int = 5,
    num_iters: int = 100,
) -> None:
    """Benchmark the fp8/int8 quantization kernel on a random activation
    tensor and print the mean per-iteration latency in microseconds.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # Static quantization supplies a fixed scale; dynamic passes None.
    scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

    # Resolve the kernel once — quant_dtype never changes inside the loop.
    if quant_dtype == torch.int8:
        quant_op = ops.scaled_int8_quant
    else:
        quant_op = ops.scaled_fp8_quant

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Time `num_iters` launches; return mean seconds per iteration."""
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        t_begin = time.perf_counter()

        for _ in range(num_iters):
            quant_op(x, scale)
        torch.cuda.synchronize()

        t_end = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (t_end - t_begin) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark: one profiled iteration, or the requested timed count.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == "__main__":

    def to_torch_dtype(dt):
        """Map a CLI dtype string to the corresponding torch dtype."""
        dtype_map = {"int8": torch.int8, "fp8": torch.float8_e4m3fn}
        if dt in dtype_map:
            return dtype_map[dt]
        raise ValueError(f"Unsupported dtype: {dt}")

    parser = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel."
    )
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--static-scale", action="store_true")
    parser.add_argument(
        "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
    )
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations. "
        "If --profile is set, this number is ignored",
    )

    args = parser.parse_args()
    print(args)

    main(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        static_scale=args.static_scale,
        quant_dtype=to_torch_dtype(args.quant_dtype),
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
    )
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_rmsnorm.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
from
typing
import
Optional
,
Union
import
torch
from
flashinfer.norm
import
fused_add_rmsnorm
,
rmsnorm
from
torch
import
nn
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm.triton_utils
import
triton
class HuggingFaceRMSNorm(nn.Module):
    """Reference RMSNorm in the HuggingFace style.

    Normalizes in float32 and casts back to the input dtype. When a residual
    tensor is supplied, it is added to the input first and the pre-norm sum
    (in the original dtype) is returned alongside the normalized output.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Per-channel scale, initialized to identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        input_dtype = x.dtype
        hidden = x.to(torch.float32)
        if residual is not None:
            hidden = hidden + residual.to(torch.float32)
            residual = hidden.to(input_dtype)

        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + self.variance_epsilon)
        normed = hidden.to(input_dtype) * self.weight

        return normed if residual is None else (normed, residual)
def rmsnorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm baseline via the HuggingFace-style reference module.

    Flattens to 2-D for the module, then restores the caller's shape.
    Returns a (normed, residual) pair when `residual` is given.
    """
    ref_module = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
    ref_module.weight = nn.Parameter(weight)
    ref_module = ref_module.to(x.device)

    orig_shape = x.shape
    flat_x = x.view(-1, x.shape[-1])
    flat_residual = (
        residual.view(-1, residual.shape[-1]) if residual is not None else None
    )

    result = ref_module(flat_x, flat_residual)

    if isinstance(result, tuple):
        return result[0].view(orig_shape), result[1].view(orig_shape)
    return result.view(orig_shape)
def rmsnorm_flashinfer(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm via FlashInfer kernels.

    With a residual, the fused kernel updates `x` and `residual` in place
    and a (normed, residual) pair is returned; otherwise a fresh normalized
    tensor is returned. Output shape matches the input shape.
    """
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])
        fused_add_rmsnorm(x, residual, weight, eps)
        output = (x, residual)
    else:
        output = rmsnorm(x, weight, eps)

    if isinstance(output, tuple):
        return output[0].view(orig_shape), output[1].view(orig_shape)
    return output.view(orig_shape)
def rmsnorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm via vLLM custom ops.

    With a residual, the fused op updates `x` and `residual` in place and a
    (normed, residual) pair is returned; otherwise the result is written to
    a fresh buffer. Output shape matches the input shape.
    """
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])
        vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
        output = (x, residual)
    else:
        out_buf = torch.empty_like(x)
        vllm_ops.rms_norm(out_buf, x, weight, eps)
        output = out_buf

    if isinstance(output, tuple):
        return output[0].view(orig_shape), output[1].view(orig_shape)
    return output.view(orig_shape)
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    """Run all three RMSNorm implementations on the same random input and
    report whether their outputs agree within a 1e-2 tolerance."""
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x) if use_residual else None

    def _residual_copy():
        # The fused kernels mutate their residual in place, so every
        # implementation gets its own clone.
        return residual.clone() if residual is not None else None

    output_naive = rmsnorm_naive(x.clone(), weight, _residual_copy())
    output_flashinfer = rmsnorm_flashinfer(x.clone(), weight, _residual_copy())
    output_vllm = rmsnorm_vllm(x.clone(), weight, _residual_copy())

    if use_residual:
        # With a residual the implementations return (normed, residual);
        # compare only the normalized outputs.
        output_naive = output_naive[0]
        output_flashinfer = output_flashinfer[0]
        output_vllm = output_vllm[0]

    print(f"Naive output={output_naive}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"vLLM output={output_vllm}")

    naive_vs_flashinfer = torch.allclose(
        output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
    )
    naive_vs_vllm = torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
    if naive_vs_flashinfer and naive_vs_vllm:
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
# Sweep grid for the perf report: powers of two for batch size (1..64,
# stepping 4x) and sequence length (64..1024), two head counts.
batch_size_range = [1, 4, 16, 64]
seq_length_range = [64, 128, 256, 512, 1024]
head_num_range = [32, 48]
# Cartesian product, seq_len varying fastest (same order as
# itertools.product(head_num_range, batch_size_range, seq_length_range)).
configs = [
    (head_num, batch, seq)
    for head_num in head_num_range
    for batch in batch_size_range
    for seq in seq_length_range
]
def get_benchmark(use_residual):
    """Build a triton perf-report benchmark closure over `use_residual`.

    The returned callable sweeps the module-level `configs` grid and times
    the HuggingFace, FlashInfer, and vLLM RMSNorm implementations.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["head_num", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["huggingface", "flashinfer", "vllm"],
            line_names=["HuggingFace", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
            args={},
        )
    )
    def benchmark(head_num, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        hidden_size = head_num * 128  # assuming head_dim = 128

        x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
        weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
        residual = torch.randn_like(x) if use_residual else None

        # Report median plus the 20th/80th percentile spread.
        quantiles = [0.5, 0.2, 0.8]

        # Inputs are cloned per call because the fused kernels mutate
        # x/residual in place.
        if provider == "huggingface":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_naive(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        elif provider == "flashinfer":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_flashinfer(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_vllm(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )

        # Convert milliseconds to microseconds for the "us" ylabel.
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length",
    )
    parser.add_argument(
        "--hidden-size",
        type=int,
        default=4096,
        help="Hidden size (2nd dimension) of the sequence",
    )
    parser.add_argument(
        "--use-residual", action="store_true", help="Whether to use residual connection"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/rmsnorm/",
        help="Path to save rmsnorm benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test on the single user-specified shape first.
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_size=args.hidden_size,
        use_residual=args.use_residual,
    )

    # Get the benchmark function with proper use_residual setting.
    benchmark = get_benchmark(args.use_residual)
    # Run performance benchmark over the module-level config grid.
    benchmark.run(print_data=True, save_path=args.save_path)
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_rope.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
itertools
import
accumulate
from
typing
import
Optional
import
nvtx
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
,
get_rope
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: float = 10000,
) -> None:
    """Compare batched multi-scaling-factor RoPE against per-factor RoPE.

    Times (via NVTX ranges, not wall-clock) one batched RoPE handling four
    linear scaling factors at once versus four separate RoPE instances each
    handling one factor.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size

    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        {"rope_type": "linear", "factor": tuple(scaling_factors)},
    )
    # non-batched RoPE takes only one scaling factor, we create multiple
    # instances to simulate the same behavior
    non_batched_ropes: list[RotaryEmbedding] = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                {"rope_type": "linear", "factor": (scaling_factor,)},
            )
        )

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
    key = torch.randn_like(query)

    # create query offsets for batched RoPE, we concat multiple kv cache
    # together and each query needs to find the right kv cache of its type
    offset_map = torch.tensor(
        list(
            accumulate(
                [0]
                + [
                    max_position * scaling_factor * 2
                    for scaling_factor in scaling_factors[:-1]
                ]
            )
        )
    )
    # Randomly assign each (batch, position) query to one of the 4 factors.
    query_types = torch.randint(
        0, len(scaling_factors), (batch_size, seq_len), device=device
    )
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()

    # batched queries of the same type together for non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
    packed_qkr = zip(queries, keys, non_batched_ropes)
    # synchronize before start timing
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        for q, k, r in packed_qkr:
            r.forward(positions, q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()
if __name__ == "__main__":
    import argparse

    def _str_to_bool(value: str) -> bool:
        """Parse a CLI boolean.

        The original used `type=bool`, which treats ANY non-empty string —
        including "False" — as True, so `--is-neox-style False` silently
        kept the default. Parse common spellings explicitly instead.
        """
        lowered = value.strip().lower()
        if lowered in ("true", "1", "yes", "y"):
            return True
        if lowered in ("false", "0", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Invalid boolean value: {value!r}")

    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    # Backward compatible: default is still True and "--is-neox-style True"
    # still parses to True; "False" now actually yields False.
    parser.add_argument("--is-neox-style", type=_str_to_bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument(
        "--dtype", type=str, choices=["bfloat16", "float"], default="float"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
    )
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
online_apiserver_test_maxbs/benchmarks/kernels/benchmark_shapes.py
0 → 100644
View file @
d1a06223
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Per-model GEMM weight shapes, keyed by "<model>/TP<n>". Each entry is a
# list of [dim0, dim1] pairs — presumably the four linear layers of one
# transformer block (qkv_proj, o_proj, gate_up_proj, down_proj) already
# sharded for that tensor-parallel degree; TODO confirm against consumers.
WEIGHT_SHAPES = {
    # Synthetic best-case shape for kernel tuning.
    "ideal": [[4 * 256 * 32, 256 * 32]],
    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],
        [2048, 4096],
        [4096, 14336],
        [7168, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP4": [
        [4096, 1536],
        [1024, 4096],
        [4096, 7168],
        [3584, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP1": [
        [4096, 12288],
        [4096, 4096],
        [4096, 22016],
        [11008, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP2": [
        [4096, 6144],
        [2048, 4096],
        [4096, 11008],
        [5504, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP4": [
        [4096, 3072],
        [1024, 4096],
        [4096, 5504],
        [2752, 4096],
    ],
    "meta-llama/Llama-2-13b-hf/TP1": [
        [5120, 15360],
        [5120, 5120],
        [5120, 27648],
        [13824, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP2": [
        [5120, 7680],
        [2560, 5120],
        [5120, 13824],
        [6912, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP4": [
        [5120, 3840],
        [1280, 5120],
        [5120, 6912],
        [3456, 5120],
    ],
    "meta-llama/Llama-2-70b-hf/TP1": [
        [8192, 10240],
        [8192, 8192],
        [8192, 57344],
        [28672, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP2": [
        [8192, 5120],
        [4096, 8192],
        [8192, 28672],
        [14336, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP4": [
        [8192, 2560],
        [2048, 8192],
        [8192, 14336],
        [7168, 8192],
    ],
}
# MoE benchmark shapes. Each row has four integers — presumably
# [num_experts, top_k, dim0, dim1]; TODO confirm against the benchmark that
# consumes this table.
WEIGHT_SHAPES_MOE = {
    "nm-testing/Mixtral-8x7B-Instruct-v0.1": [
        [8, 2, 4096, 28672],
        [8, 2, 14336, 4096],
    ],
    "nm-testing/deepseekv2-lite": [
        [64, 6, 2048, 1408],
    ],
    "ibm-granite/granite-3.0-1b-a400m": [
        [32, 8, 1024, 1024],
    ],
    "ibm-granite/granite-3.0-3b-a800m": [
        [40, 8, 1024, 1536],
    ],
}
Prev
1
…
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment