ModelZoo / LLaVA_vllm · Commits

Commit c2170174
Authored Oct 23, 2025 by laibao

Update README.md: change the Docker image version, adjust the environment-variable format, improve the benchmark scripts, add new benchmark features, and remove example files that are no longer used.

Parent: ff2c99fd
Changes: 54 · Showing 20 changed files with 5956 additions and 685 deletions (+5956, -685)
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py  +63  -0
benchmarks/disagg_benchmarks/round_robin_proxy.py  +63  -0
benchmarks/disagg_benchmarks/visualize_benchmark_results.py  +47  -0
benchmarks/fused_kernels/layernorm_rms_benchmarks.py  +228  -0
benchmarks/kernels/bench_fp8_gemm.py  +159  -0
benchmarks/kernels/bench_int8_gemm.py  +169  -0
benchmarks/kernels/benchmark_aqlm.py  +121  -78
benchmarks/kernels/benchmark_bitblas.py  +242  -0
benchmarks/kernels/benchmark_cutlass_fp4_moe.py  +490  -0
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py  +383  -0
benchmarks/kernels/benchmark_layernorm.py  +39  -32
benchmarks/kernels/benchmark_lora.py  +1065  -0
benchmarks/kernels/benchmark_machete.py  +487  -175
benchmarks/kernels/benchmark_marlin.py  +254  -95
benchmarks/kernels/benchmark_moe.py  +639  -160
benchmarks/kernels/benchmark_moe_align_block_size.py  +159  -0
benchmarks/kernels/benchmark_moe_int4.py  +713  -0
benchmarks/kernels/benchmark_moe_permute_unpermute.py  +418  -0
benchmarks/kernels/benchmark_paged_attention.py  +172  -108
benchmarks/kernels/benchmark_quant.py  +45  -37
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py (new file, mode 100644)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import aiohttp
from quart import Quart, make_response, request

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)


async def forward_request(url, data):
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status == 200:
                # if response.headers.get('Transfer-Encoding') == 'chunked':
                if True:
                    async for chunk_bytes in response.content.iter_chunked(1024):
                        yield chunk_bytes
                else:
                    content = await response.read()
                    yield content


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    try:
        original_request_data = await request.get_json()

        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        # finish prefill
        async for _ in forward_request(
            "http://localhost:8100/v1/completions", prefill_request
        ):
            continue

        # return decode
        generator = forward_request(
            "http://localhost:8200/v1/completions", original_request_data
        )
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))


if __name__ == "__main__":
    app.run(port=8000)
```
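For readers who want to exercise the proxy end to end, here is a minimal client sketch. It assumes two OpenAI-compatible vLLM servers are already running locally on ports 8100 (prefill) and 8200 (decode), exactly as the proxy expects; the model name and prompt below are placeholders, not part of this commit.

```python
# Minimal smoke test for the disaggregated-prefill proxy above.
import requests

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical; use your served model
    "prompt": "San Francisco is a",
    "max_tokens": 32,
    "stream": True,
}

# The proxy first runs prefill on :8100, then streams the decode output from :8200.
with requests.post(
    "http://localhost:8000/v1/completions", json=payload, stream=True
) as r:
    for chunk in r.iter_content(chunk_size=None):
        print(chunk.decode("utf-8", errors="replace"), end="")
```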
benchmarks/disagg_benchmarks/round_robin_proxy.py (new file, mode 100644)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import itertools

import aiohttp
from aiohttp import web


class RoundRobinProxy:
    def __init__(self, target_ports):
        self.target_ports = target_ports
        self.port_cycle = itertools.cycle(self.target_ports)

    async def handle_request(self, request):
        target_port = next(self.port_cycle)
        target_url = f"http://localhost:{target_port}{request.path_qs}"

        async with aiohttp.ClientSession() as session:
            try:
                # Forward the request
                async with session.request(
                    method=request.method,
                    url=target_url,
                    headers=request.headers,
                    data=request.content,
                ) as response:
                    # Start sending the response
                    resp = web.StreamResponse(
                        status=response.status, headers=response.headers
                    )
                    await resp.prepare(request)

                    # Stream the response content
                    async for chunk in response.content.iter_any():
                        await resp.write(chunk)

                    await resp.write_eof()
                    return resp

            except Exception as e:
                return web.Response(text=f"Error: {str(e)}", status=500)


async def main():
    proxy = RoundRobinProxy([8100, 8200])
    app = web.Application()
    app.router.add_route("*", "/{path:.*}", proxy.handle_request)

    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "localhost", 8000)
    await site.start()

    print("Proxy server started on http://localhost:8000")

    # Keep the server running
    await asyncio.Event().wait()


if __name__ == "__main__":
    asyncio.run(main())
```
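A quick way to observe the round-robin behaviour is to fire a handful of identical requests at the proxy and confirm they alternate between the two backends. This is only a sketch: it assumes the servers on ports 8100 and 8200 expose a `/v1/models` endpoint (as vLLM's OpenAI-compatible server does); adjust the path if your backends differ.

```python
import requests

# Each request should be forwarded to the next port in the cycle: 8100, 8200, 8100, ...
for i in range(4):
    resp = requests.get("http://localhost:8000/v1/models")
    print(i, resp.status_code)
```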
benchmarks/disagg_benchmarks/visualize_benchmark_results.py (new file, mode 100644)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import matplotlib.pyplot as plt
import pandas as pd

if __name__ == "__main__":
    data = []
    for name in ["disagg_prefill", "chunked_prefill"]:
        for qps in [2, 4, 6, 8]:
            with open(f"results/{name}-qps-{qps}.json") as f:
                x = json.load(f)
                x["name"] = name
                x["qps"] = qps
                data.append(x)

    df = pd.DataFrame.from_dict(data)
    dis_df = df[df["name"] == "disagg_prefill"]
    chu_df = df[df["name"] == "chunked_prefill"]

    plt.style.use("bmh")
    plt.rcParams["font.size"] = 20

    for key in [
        "mean_ttft_ms",
        "median_ttft_ms",
        "p99_ttft_ms",
        "mean_itl_ms",
        "median_itl_ms",
        "p99_itl_ms",
    ]:
        fig, ax = plt.subplots(figsize=(11, 7))
        plt.plot(
            dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
        )
        plt.plot(
            chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
        )
        ax.legend()

        ax.set_xlabel("QPS")
        ax.set_ylabel(key)
        ax.set_ylim(bottom=0)
        fig.savefig(f"results/{key}.png")
        plt.close(fig)
```
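The plotting script assumes a `results/` directory populated by the serving benchmarks, with one JSON file per (mode, QPS) pair containing the TTFT and ITL statistics it plots. The sketch below only illustrates that expected layout with placeholder numbers; it is not part of the commit and does not reflect real measurements.

```python
import json
import os

keys = ["mean_ttft_ms", "median_ttft_ms", "p99_ttft_ms",
        "mean_itl_ms", "median_itl_ms", "p99_itl_ms"]

os.makedirs("results", exist_ok=True)
for name in ["disagg_prefill", "chunked_prefill"]:
    for qps in [2, 4, 6, 8]:
        # Placeholder values so the plotting loop above has something to read.
        record = {k: 100.0 + qps for k in keys}
        with open(f"results/{name}-qps-{qps}.json", "w") as f:
            json.dump(record, f)
```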
benchmarks/fused_kernels/layernorm_rms_benchmarks.py (new file, mode 100644)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pickle as pkl
import time
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import product
from typing import Callable, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm


@dataclass
class bench_params_t:
    num_tokens: int
    hidden_size: int
    add_residual: bool
    dtype: torch.dtype

    def description(self):
        return (
            f"N {self.num_tokens} x D {self.hidden_size} "
            f"x R {self.add_residual} x DT {self.dtype}"
        )


def get_bench_params() -> list[bench_params_t]:
    ## Test Fixtures
    NUM_TOKENS = [2**x for x in range(11)]
    HIDDEN_SIZES = list(range(1024, 8129, 1024))
    ADD_RESIDUAL = [True, False]
    DTYPES = [torch.bfloat16, torch.float]

    combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
    bench_params = list(
        map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
    )
    return bench_params


# Reference impls
def unfused_int8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    # Norm
    torch_out = None
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant
    torch_out, _, _ = ops.scaled_int8_quant(torch_out)


def unfused_fp8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    # Norm
    torch_out = None
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant
    torch_out, _ = ops.scaled_fp8_quant(torch_out)


def fused_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    out, _ = ops.rms_norm_dynamic_per_token_quant(
        x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
    )


# Bench functions
def bench_fn(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor,
    quant_dtype: torch.dtype,
    label: str,
    sub_label: str,
    fn: Callable,
    description: str,
) -> TMeasurement:
    min_run_time = 1

    globals = {
        "rms_norm_layer": rms_norm_layer,
        "x": x,
        "residual": residual,
        "quant_dtype": quant_dtype,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
        globals=globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
    # Make inputs
    layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
    # Make weights
    layer.weight.data.normal_(mean=1.0, std=0.1)

    # Make inputs
    scale = 1 / params.hidden_size
    x = (
        torch.randn(
            params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
        )
        * scale
    )
    residual = (
        (torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
    )

    timers = []

    # unfused int8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.int8, label, sub_label,
                 unfused_int8_impl, "unfused_int8_impl")
    )

    # unfused fp8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
                 unfused_fp8_impl, "unfused_fp8_impl")
    )

    # fused int8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.int8, label, sub_label,
                 fused_impl, "fused_int8_impl")
    )

    # fused fp8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
                 fused_impl, "fused_fp8_impl")
    )

    print_timers(timers)
    return timers


# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def main():
    torch.set_default_device("cuda")
    bench_params = get_bench_params()

    timers = []
    for bp in tqdm(bench_params):
        timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
    print_timers(timers)

    # pickle all the results
    timestamp = int(time.time())
    with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
        pkl.dump(timers, f)


if __name__ == "__main__":
    main()
```
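Because the measurements are pickled, the comparison table can be regenerated later without rerunning the kernels. A small follow-up sketch, assuming at least one `rms_norm_dpt_quant-*.pkl` file produced by `main()` exists in the working directory:

```python
import glob
import pickle as pkl

import torch.utils.benchmark as TBenchmark

# Reload the most recent results file and reprint the comparison table.
latest = sorted(glob.glob("rms_norm_dpt_quant-*.pkl"))[-1]
with open(latest, "rb") as f:
    timers = pkl.load(f)
TBenchmark.Compare(timers).print()
```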
benchmarks/kernels/bench_fp8_gemm.py (new file, mode 100644)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "fp8-tensor-w-token-a": dict(w="tensor", a="token", no_a_quant=False, enabled=False),
    "fp8-tensor-w-tensor-a": dict(w="tensor", a="tensor", no_a_quant=False, enabled=True),
    "fp8-channel-w-token-a": dict(w="channel", a="token", no_a_quant=False, enabled=True),
    "fp8-channel-w-tensor-a": dict(w="channel", a="tensor", no_a_quant=False, enabled=False),
    "fp8-tensor-w-token-a-noquant": dict(w="tensor", a="token", no_a_quant=True, enabled=False),
    "fp8-tensor-w-tensor-a-noquant": dict(w="tensor", a="tensor", no_a_quant=True, enabled=True),
    "fp8-channel-w-token-a-noquant": dict(w="channel", a="token", no_a_quant=True, enabled=True),
    "fp8-channel-w-tensor-a-noquant": dict(w="channel", a="tensor", no_a_quant=True, enabled=False),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
    if w_type == "tensor":
        scale_b = torch.ones(1, device=device, dtype=torch.float32)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    else:
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
    return b_fp8.t(), scale_b_fp8


def build_fp8_runner(cfg, a, b, dtype, device):
    b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)

    scale_a_const = (
        torch.ones(1, device=device, dtype=torch.float32)
        if cfg["a"] == "tensor"
        else None
    )

    if cfg["no_a_quant"]:
        if cfg["a"] == "tensor":
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
        else:
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)

        def run():
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        return run

    if cfg["a"] == "tensor":

        def run():
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    else:

        def run():
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs FP8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_fp8_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_fp8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
```
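As a sanity check on the numbers this benchmark reports, the TFLOP/s conversion used in `benchmark()` can be reproduced by hand; the shape and timing below are illustrative only, not measured values.

```python
# A GEMM performs 2*M*N*K floating-point operations; dividing by the runtime in
# seconds and scaling by 1e-12 gives TFLOP/s, matching the to_tflops lambda above.
M = N = K = 4096
t_ms = 10.0  # hypothetical measured latency in milliseconds
tflops = (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
print(f"{tflops:.2f} TFLOP/s")  # ~13.74
```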
benchmarks/kernels/bench_int8_gemm.py
0 → 100644
View file @
c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
torch
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm._custom_ops
import
cutlass_scaled_mm
as
vllm_scaled_mm
from
vllm._custom_ops
import
scaled_int8_quant
as
vllm_scaled_int8_quant
from
vllm.triton_utils
import
triton
PROVIDER_CFGS
=
{
"torch-bf16"
:
dict
(
enabled
=
True
),
"int8-tensor-w-token-a"
:
dict
(
w
=
"tensor"
,
a
=
"token"
,
no_a_quant
=
False
,
enabled
=
False
),
"int8-tensor-w-tensor-a"
:
dict
(
w
=
"tensor"
,
a
=
"tensor"
,
no_a_quant
=
False
,
enabled
=
True
),
"int8-channel-w-token-a"
:
dict
(
w
=
"channel"
,
a
=
"token"
,
no_a_quant
=
False
,
enabled
=
True
),
"int8-channel-w-tensor-a"
:
dict
(
w
=
"channel"
,
a
=
"tensor"
,
no_a_quant
=
False
,
enabled
=
False
),
"int8-tensor-w-token-a-noquant"
:
dict
(
w
=
"tensor"
,
a
=
"token"
,
no_a_quant
=
True
,
enabled
=
False
),
"int8-tensor-w-tensor-a-noquant"
:
dict
(
w
=
"tensor"
,
a
=
"tensor"
,
no_a_quant
=
True
,
enabled
=
True
),
"int8-channel-w-token-a-noquant"
:
dict
(
w
=
"channel"
,
a
=
"token"
,
no_a_quant
=
True
,
enabled
=
True
),
"int8-channel-w-tensor-a-noquant"
:
dict
(
w
=
"channel"
,
a
=
"tensor"
,
no_a_quant
=
True
,
enabled
=
False
),
}
def
_quant_weight
(
b
,
w_type
,
device
):
if
w_type
==
"tensor"
:
scale_b
=
torch
.
ones
(
1
,
device
=
device
,
dtype
=
torch
.
float32
)
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
,
scale_b
)
assert
scale_b_int8
.
numel
()
==
1
else
:
# channel
b_int8
,
scale_b_int8
,
_
=
vllm_scaled_int8_quant
(
b
)
assert
scale_b_int8
.
numel
()
==
b
.
shape
[
0
]
return
b_int8
.
t
(),
scale_b_int8
def
build_int8_runner
(
cfg
,
a
,
b
,
dtype
,
device
):
# quant before running the kernel
b_int8
,
scale_b_int8
=
_quant_weight
(
b
,
cfg
[
"w"
],
device
)
scale_a_const
=
None
if
cfg
[
"a"
]
==
"tensor"
:
scale_a_const
=
torch
.
ones
(
1
,
device
=
device
,
dtype
=
torch
.
float32
)
# no quant, create activation ahead
if
cfg
[
"no_a_quant"
]:
if
cfg
[
"a"
]
==
"tensor"
:
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
,
scale_a_const
)
else
:
# token
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
)
def
run_quant
():
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
return
run_quant
# dynamic quant, create activation inside
if
cfg
[
"a"
]
==
"tensor"
:
def
run_quant
():
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
,
scale_a_const
)
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
else
:
# token
def
run_quant
():
a_int8
,
scale_a_int8
,
_
=
vllm_scaled_int8_quant
(
a
)
return
vllm_scaled_mm
(
a_int8
,
b_int8
,
scale_a_int8
,
scale_b_int8
,
dtype
)
return
run_quant
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
.
get
(
"enabled"
)]
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_vals
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
_enabled
,
line_names
=
[
k
for
k
in
_enabled
],
ylabel
=
"TFLOP/s (larger is better)"
,
plot_name
=
"BF16 vs INT8 GEMMs"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
provider
,
N
,
K
):
M
=
batch_size
device
=
"cuda"
dtype
=
torch
.
bfloat16
a
=
torch
.
randn
((
M
,
K
),
device
=
device
,
dtype
=
dtype
)
b
=
torch
.
randn
((
N
,
K
),
device
=
device
,
dtype
=
dtype
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"torch-bf16"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
torch
.
nn
.
functional
.
linear
(
a
,
b
),
quantiles
=
quantiles
)
else
:
cfg
=
PROVIDER_CFGS
[
provider
]
run_quant
=
build_int8_runner
(
cfg
,
a
,
b
,
dtype
,
device
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
run_quant
(),
quantiles
=
quantiles
)
to_tflops
=
lambda
t_ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
t_ms
*
1e-3
)
return
to_tflops
(
ms
),
to_tflops
(
max_ms
),
to_tflops
(
min_ms
)
def
prepare_shapes
(
args
):
KN_model_names
=
[]
for
model
,
tp_size
in
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
):
for
KN
,
tp_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_dim
]
//=
tp_size
KN
.
append
(
model
)
KN_model_names
.
append
(
KN
)
return
KN_model_names
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.1-8B-Instruct"
],
choices
=
list
(
WEIGHT_SHAPES
.
keys
()),
help
=
"List of models to benchmark"
,
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
],
help
=
"List of tensor parallel sizes"
,
)
args
=
parser
.
parse_args
()
for
K
,
N
,
model
in
prepare_shapes
(
args
):
print
(
f
"
{
model
}
, N=
{
N
}
K=
{
K
}
, BF16 vs INT8 GEMMs TFLOP/s:"
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
save_path
=
f
"bench_int8_res_n
{
N
}
_k
{
K
}
"
,
N
=
N
,
K
=
K
,
)
print
(
"Benchmark finished!"
)
benchmarks/kernels/benchmark_aqlm.py
View file @
c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
sys
from
typing
import
Optional
...
...
@@ -7,32 +10,39 @@ import torch.nn.functional as F
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.aqlm
import
(
dequantize_weight
,
generic_dequantize_gemm
,
get_int_dtype
,
optimized_dequantize_gemm
)
dequantize_weight
,
generic_dequantize_gemm
,
get_int_dtype
,
optimized_dequantize_gemm
,
)
from
vllm.utils
import
FlexibleArgumentParser
os
.
environ
[
'
CUDA_VISIBLE_DEVICES
'
]
=
'0'
os
.
environ
[
"
CUDA_VISIBLE_DEVICES
"
]
=
"0"
def
torch_mult
(
input
:
torch
.
Tensor
,
# [..., in_features]
weights
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
# [..., in_features]
input
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
scales
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
output
=
F
.
linear
(
input
,
weights
)
return
output
def
dequant_out_scale
(
input
:
torch
.
Tensor
,
# [..., in_features]
codes
:
torch
.
IntTensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
# [..., in_features]
input
:
torch
.
Tensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codes
:
torch
.
IntTensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
scales
:
torch
.
Tensor
,
output_partition_sizes
:
torch
.
IntTensor
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
weights
=
ops
.
aqlm_dequant
(
codes
,
codebooks
,
output_partition_sizes
)
if
bias
is
None
:
...
...
@@ -44,40 +54,42 @@ def dequant_out_scale(
flattened_output
*=
b_scales
return
flattened_output
.
view
(
orig_shape
)
else
:
b_scales
=
scales
.
view
(
scales
.
shape
[:
-
3
]
+
(
-
1
,
)).
expand
(
-
1
,
weights
.
shape
[
1
])
b_scales
=
scales
.
view
(
scales
.
shape
[:
-
3
]
+
(
-
1
,)).
expand
(
-
1
,
weights
.
shape
[
1
])
weights
*=
b_scales
return
F
.
linear
(
input
,
weights
,
bias
)
def
dequant_weight_scale
(
input
:
torch
.
Tensor
,
# [..., in_features]
codes
:
torch
.
IntTensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
# [..., in_features]
input
:
torch
.
Tensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codes
:
torch
.
IntTensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
scales
:
torch
.
Tensor
,
output_partition_sizes
:
torch
.
IntTensor
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
weights
=
ops
.
aqlm_dequant
(
codes
,
codebooks
,
output_partition_sizes
)
b_scales
=
scales
.
view
(
scales
.
shape
[:
-
3
]
+
(
-
1
,
)).
expand
(
-
1
,
weights
.
shape
[
1
])
b_scales
=
scales
.
view
(
scales
.
shape
[:
-
3
]
+
(
-
1
,)).
expand
(
-
1
,
weights
.
shape
[
1
])
weights
*=
b_scales
return
F
.
linear
(
input
,
weights
,
bias
)
def
dequant_no_scale
(
input
:
torch
.
Tensor
,
# [..., in_features]
codes
:
torch
.
IntTensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
# [..., in_features]
input
:
torch
.
Tensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codes
:
torch
.
IntTensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
scales
:
torch
.
Tensor
,
output_partition_sizes
:
torch
.
IntTensor
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
weights
=
ops
.
aqlm_dequant
(
codes
,
codebooks
,
output_partition_sizes
)
return
F
.
linear
(
input
,
weights
,
bias
)
...
...
@@ -87,23 +99,26 @@ def dequant_no_scale(
# the generic pytorch version.
# Just visual comparison.
def
dequant_test
(
k
:
int
,
parts
:
torch
.
Tensor
,
nbooks
:
int
,
bits
:
int
)
->
None
:
n
=
int
(
parts
.
sum
().
item
())
device
=
torch
.
device
(
'
cuda:0
'
)
device
=
torch
.
device
(
"
cuda:0
"
)
code_range
=
(
1
<<
bits
)
//
2
ingroups
=
8
codes
=
torch
.
randint
(
-
code_range
,
code_range
,
size
=
(
n
,
k
//
ingroups
,
nbooks
),
dtype
=
get_int_dtype
(
bits
),
device
=
device
)
codes
=
torch
.
randint
(
-
code_range
,
code_range
,
size
=
(
n
,
k
//
ingroups
,
nbooks
),
dtype
=
get_int_dtype
(
bits
),
device
=
device
,
)
codebooks
=
torch
.
randn
(
size
=
(
parts
.
shape
[
0
]
*
nbooks
,
1
<<
bits
,
1
,
8
),
dtype
=
torch
.
float16
,
device
=
device
)
codebooks
=
torch
.
randn
(
size
=
(
parts
.
shape
[
0
]
*
nbooks
,
1
<<
bits
,
1
,
8
),
dtype
=
torch
.
float16
,
device
=
device
,
)
count
=
0
for
index
in
range
(
16
):
...
...
@@ -136,24 +151,25 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
def
main
():
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark aqlm performance."
)
# Add arguments
parser
.
add_argument
(
"--nbooks"
,
type
=
int
,
default
=
1
,
help
=
"Number of codebooks (default: 1)"
)
parser
.
add_argument
(
"--bits"
,
type
=
int
,
default
=
16
,
help
=
"Number of bits per code element (default: 16)"
)
parser
.
add_argument
(
"--nbooks"
,
type
=
int
,
default
=
1
,
help
=
"Number of codebooks (default: 1)"
)
parser
.
add_argument
(
"--bits"
,
type
=
int
,
default
=
16
,
help
=
"Number of bits per code element (default: 16)"
,
)
parser
.
add_argument
(
"--test"
,
type
=
bool
,
default
=
False
,
help
=
"Run the decompression/dequant tester rather than benchmarking "
"(default: False)"
)
"(default: False)"
,
)
# Parse the arguments
args
=
parser
.
parse_args
()
...
...
@@ -163,7 +179,7 @@ def main():
bits
=
args
.
bits
if
args
.
test
:
dequant_test
(
4096
,
torch
.
tensor
((
4096
,
)),
nbooks
,
bits
)
dequant_test
(
4096
,
torch
.
tensor
((
4096
,)),
nbooks
,
bits
)
return
# Otherwise, benchmark.
...
...
@@ -182,31 +198,54 @@ def main():
with
open
(
filename
,
"w"
)
as
f
:
sys
.
stdout
=
f
print
(
'
m | k | n | n parts
'
,
end
=
''
)
print
(
"
m | k | n | n parts
"
,
end
=
""
)
for
method
in
methods
:
print
(
f
" |
{
method
.
__name__
.
replace
(
'_'
,
' '
)
}
(µs)"
,
end
=
''
)
print
(
''
)
print
(
f
" |
{
method
.
__name__
.
replace
(
'_'
,
' '
)
}
(µs)"
,
end
=
""
)
print
(
""
)
# These are reasonable prefill sizes.
ksandpartions
=
((
4096
,
(
4096
,
4096
,
4096
)),
(
4096
,
(
4096
,
)),
(
4096
,
(
11008
,
11008
)),
(
11008
,
(
4096
,
)))
ksandpartions
=
(
(
4096
,
(
4096
,
4096
,
4096
)),
(
4096
,
(
4096
,)),
(
4096
,
(
11008
,
11008
)),
(
11008
,
(
4096
,)),
)
# reasonable ranges for m.
for
m
in
[
1
,
2
,
4
,
8
,
10
,
12
,
14
,
16
,
24
,
32
,
48
,
52
,
56
,
64
,
96
,
112
,
128
,
256
,
512
,
1024
,
1536
,
2048
,
3072
,
4096
1
,
2
,
4
,
8
,
10
,
12
,
14
,
16
,
24
,
32
,
48
,
52
,
56
,
64
,
96
,
112
,
128
,
256
,
512
,
1024
,
1536
,
2048
,
3072
,
4096
,
]:
print
(
f
'
{
m
}
'
,
file
=
sys
.
__stdout__
)
print
(
f
"
{
m
}
"
,
file
=
sys
.
__stdout__
)
for
ksp
in
ksandpartions
:
run_grid
(
m
,
ksp
[
0
],
torch
.
tensor
(
ksp
[
1
]),
nbooks
,
bits
,
methods
)
run_grid
(
m
,
ksp
[
0
],
torch
.
tensor
(
ksp
[
1
]),
nbooks
,
bits
,
methods
)
sys
.
stdout
=
sys
.
__stdout__
def
run_grid
(
m
:
int
,
k
:
int
,
parts
:
torch
.
Tensor
,
nbooks
:
int
,
bits
:
int
,
methods
):
def
run_grid
(
m
:
int
,
k
:
int
,
parts
:
torch
.
Tensor
,
nbooks
:
int
,
bits
:
int
,
methods
):
# I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials
=
1
num_trials
=
1
...
...
@@ -227,7 +266,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
)
n
=
parts
.
sum
().
item
()
print
(
f
'
{
m
}
|
{
k
}
|
{
n
}
|
{
parts
.
tolist
()
}
'
,
end
=
''
)
print
(
f
"
{
m
}
|
{
k
}
|
{
n
}
|
{
parts
.
tolist
()
}
"
,
end
=
""
)
for
method
in
methods
:
best_time_us
=
1e20
...
...
@@ -247,32 +286,36 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
if
kernel_dur_us
<
best_time_us
:
best_time_us
=
kernel_dur_us
print
(
f
'
|
{
kernel_dur_us
:.
0
f
}
'
,
end
=
''
)
print
(
f
"
|
{
kernel_dur_us
:.
0
f
}
"
,
end
=
""
)
print
(
''
)
print
(
""
)
def
run_timing
(
num_calls
:
int
,
m
:
int
,
k
:
int
,
parts
:
torch
.
Tensor
,
nbooks
:
int
,
bits
:
int
,
method
)
->
float
:
def
run_timing
(
num_calls
:
int
,
m
:
int
,
k
:
int
,
parts
:
torch
.
Tensor
,
nbooks
:
int
,
bits
:
int
,
method
)
->
float
:
n
=
int
(
parts
.
sum
().
item
())
device
=
torch
.
device
(
'
cuda:0
'
)
device
=
torch
.
device
(
"
cuda:0
"
)
input
=
torch
.
randn
((
1
,
m
,
k
),
dtype
=
torch
.
float16
,
device
=
device
)
code_range
=
(
1
<<
bits
)
//
2
ingroups
=
8
codes
=
torch
.
randint
(
-
code_range
,
code_range
,
size
=
(
n
,
k
//
ingroups
,
nbooks
),
dtype
=
get_int_dtype
(
bits
),
device
=
device
)
codebooks
=
torch
.
randn
(
size
=
(
parts
.
shape
[
0
]
*
nbooks
,
1
<<
bits
,
1
,
8
),
dtype
=
torch
.
float16
,
device
=
device
)
codes
=
torch
.
randint
(
-
code_range
,
code_range
,
size
=
(
n
,
k
//
ingroups
,
nbooks
),
dtype
=
get_int_dtype
(
bits
),
device
=
device
,
)
codebooks
=
torch
.
randn
(
size
=
(
parts
.
shape
[
0
]
*
nbooks
,
1
<<
bits
,
1
,
8
),
dtype
=
torch
.
float16
,
device
=
device
,
)
scales
=
torch
.
randn
(
size
=
(
n
,
1
,
1
,
1
),
dtype
=
torch
.
float16
,
device
=
device
)
...
...
benchmarks/kernels/benchmark_bitblas.py
0 → 100644
View file @
c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from
vllm.model_executor.layers.quantization.utils.bitblas_utils
import
(
MINIMUM_BITBLAS_VERSION
,
)
try
:
import
bitblas
if
bitblas
.
__version__
<
MINIMUM_BITBLAS_VERSION
:
raise
ImportError
(
"bitblas version is wrong. Please "
f
"install bitblas>=
{
MINIMUM_BITBLAS_VERSION
}
"
)
except
ImportError
as
e
:
bitblas_import_exception
=
e
raise
ValueError
(
"Trying to use the bitblas backend, but could not import"
f
"with the following error:
{
bitblas_import_exception
}
. "
"Please install bitblas through the following command: "
f
"`pip install bitblas>=
{
MINIMUM_BITBLAS_VERSION
}
`"
)
from
bitblas_import_exception
from
bitblas
import
Matmul
,
MatmulConfig
,
auto_detect_nvidia_target
from
vllm.utils
import
FlexibleArgumentParser
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark BitBLAS int4 on a specific target."
)
# Add arguments to the parser
parser
.
add_argument
(
"--target"
,
type
=
str
,
default
=
auto_detect_nvidia_target
(),
help
=
"Specify the target device for benchmarking."
,
)
parser
.
add_argument
(
"--group_size"
,
type
=
int
,
default
=
None
,
help
=
"Group size for grouped quantization."
)
parser
.
add_argument
(
"--A_dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float16"
,
"float32"
,
"float64"
,
"int32"
,
"int8"
],
help
=
"Data type of activation A."
,
)
parser
.
add_argument
(
"--W_dtype"
,
type
=
str
,
default
=
"int4"
,
choices
=
[
"float16"
,
"float32"
,
"float64"
,
"int32"
,
"int8"
,
"int4"
,
"int2"
,
"int1"
,
"nf4"
,
"fp4_e2m1"
,
],
help
=
"Data type of weight W."
,
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float16"
,
"int32"
],
help
=
"Data type for accumulation."
,
)
parser
.
add_argument
(
"--out_dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float16"
,
"float32"
,
"int32"
,
"int8"
],
help
=
"Data type for output."
,
)
parser
.
add_argument
(
"--layout"
,
type
=
str
,
default
=
"nt"
,
choices
=
[
"nt"
,
"nn"
],
help
=
"Matrix layout, 'nt' for non-transpose A and transpose W."
,
)
parser
.
add_argument
(
"--with_bias"
,
action
=
"store_true"
,
help
=
"Include bias in the benchmark."
)
parser
.
add_argument
(
"--with_scaling"
,
action
=
"store_true"
,
help
=
"Include scaling factor in the quantization."
,
)
parser
.
add_argument
(
"--with_zeros"
,
action
=
"store_true"
,
help
=
"Include zeros in the quantization."
)
parser
.
add_argument
(
"--zeros_mode"
,
type
=
str
,
default
=
None
,
choices
=
[
"original"
,
"rescale"
,
"quantized"
],
help
=
"Specify the mode for calculating zeros."
,
)
# Parse the arguments
args
=
parser
.
parse_args
()
# Assign arguments to variables
target
=
args
.
target
A_dtype
=
args
.
A_dtype
W_dtype
=
args
.
W_dtype
accum_dtype
=
args
.
accum_dtype
out_dtype
=
args
.
out_dtype
layout
=
args
.
layout
with_bias
=
args
.
with_bias
group_size
=
args
.
group_size
with_scaling
=
args
.
with_scaling
with_zeros
=
args
.
with_zeros
zeros_mode
=
args
.
zeros_mode
# Define a list of shared arguments that repeat in every config
shared_args
=
[
A_dtype
,
W_dtype
,
out_dtype
,
accum_dtype
,
layout
,
with_bias
,
group_size
,
with_scaling
,
with_zeros
,
zeros_mode
,
]
# Define just the (M, K, N) shapes in a more compact list
shapes
=
[
# square test
(
1
,
16384
,
16384
),
# BLOOM-176B
(
1
,
43008
,
14336
),
(
1
,
14336
,
14336
),
(
1
,
57344
,
14336
),
(
1
,
14336
,
57344
),
# OPT-65B
(
1
,
9216
,
9216
),
(
1
,
36864
,
9216
),
(
1
,
9216
,
36864
),
(
1
,
22016
,
8192
),
# LLAMA-70B/65B
(
1
,
8192
,
22016
),
(
1
,
8192
,
8192
),
(
1
,
28672
,
8192
),
(
1
,
8192
,
28672
),
# square test
(
16384
,
16384
,
16384
),
# BLOOM-176B
(
8192
,
43008
,
14336
),
(
8192
,
14336
,
14336
),
(
8192
,
57344
,
14336
),
(
8192
,
14336
,
57344
),
# OPT-65B
(
8192
,
9216
,
9216
),
(
8192
,
36864
,
9216
),
(
8192
,
9216
,
36864
),
(
8192
,
22016
,
8192
),
# LLAMA-70B/65B
(
8192
,
8192
,
22016
),
(
8192
,
8192
,
8192
),
(
8192
,
28672
,
8192
),
(
8192
,
8192
,
28672
),
]
# Build test shapes with all the shared arguments
test_shapes
=
[(
MatmulConfig
,
Matmul
,
(
*
shape
,
*
shared_args
))
for
shape
in
shapes
]
benchmark_sets
=
[]
benchmark_sets
.
extend
(
test_shapes
)
benchmark_results
=
{}
for
config_class
,
operator
,
input_args
in
benchmark_sets
:
config
=
config_class
(
*
input_args
)
matmul
=
operator
(
config
,
target
=
target
,
enable_tuning
=
True
)
kernel_latency
=
matmul
.
profile_latency
()
print
(
"Time cost is: {:.3f} ms"
.
format
(
kernel_latency
))
profile_config
=
{
f
"
{
operator
.
__name__
}
-
{
'-'
.
join
([
str
(
i
)
for
i
in
input_args
])
}
"
:
{
"BitBLAS_top20_latency"
:
kernel_latency
,
}
}
benchmark_results
.
update
(
profile_config
)
# Define headers for the table
headers
=
[
"PrimFunc"
,
"Input Arguments"
,
"BitBLAS Top20 Latency"
,
]
# Calculate column widths for pretty printing
col_widths
=
[
0
,
0
,
0
]
for
config_key
,
values
in
benchmark_results
.
items
():
args_split
=
config_key
.
split
(
"-"
)
func_name
=
args_split
[
0
]
input_args_str
=
"-"
.
join
(
args_split
[
1
:])
col_widths
[
0
]
=
max
(
col_widths
[
0
],
len
(
func_name
)
+
2
,
len
(
headers
[
0
])
+
2
)
col_widths
[
1
]
=
max
(
col_widths
[
1
],
len
(
input_args_str
)
+
2
,
len
(
headers
[
1
])
+
2
)
col_widths
[
2
]
=
max
(
col_widths
[
2
],
len
(
f
"
{
values
[
'BitBLAS_top20_latency'
]:.
3
f
}
ms"
)
+
2
,
len
(
headers
[
2
])
+
2
,
)
# break only if you want to measure widths from a single example;
# otherwise, let it loop over all items.
# Print header
for
i
,
header
in
enumerate
(
headers
):
headers
[
i
]
=
header
.
ljust
(
col_widths
[
i
])
print
(
""
.
join
(
headers
))
print
(
"-"
*
sum
(
col_widths
))
# Print rows
for
config_key
,
values
in
benchmark_results
.
items
():
args_split
=
config_key
.
split
(
"-"
)
func_name
=
args_split
[
0
]
input_args_str
=
"-"
.
join
(
args_split
[
1
:])
row
=
[
func_name
,
input_args_str
,
f
"
{
values
[
'BitBLAS_top20_latency'
]:.
3
f
}
ms"
,
]
row_str
=
""
.
join
(
[
str
(
cell
).
ljust
(
col_widths
[
idx
])
for
idx
,
cell
in
enumerate
(
row
)]
)
print
(
row_str
)
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
0 → 100644
View file @
c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8)
and 16-bit activations.
"""
import
nvtx
import
torch
import
torch.utils.benchmark
as
benchmark
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
WEIGHT_SHAPES_MOE
=
{
"nvidia/DeepSeek-R1-FP4"
:
[
[
256
,
8
,
2048
,
7168
],
],
}
DEFAULT_MODELS
=
[
"nvidia/DeepSeek-R1-FP4"
,
]
DEFAULT_BATCH_SIZES
=
[
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
]
DEFAULT_TP_SIZES
=
[
1
]
PER_ACT_TOKEN_OPTS
=
[
False
]
PER_OUT_CH_OPTS
=
[
False
]
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
bench_run
(
results
:
list
[
benchmark
.
Measurement
],
model
:
str
,
num_experts
:
int
,
topk
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
mkn
:
tuple
[
int
,
int
,
int
],
):
label
=
"NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
sub_label
=
(
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})"
.
format
(
model
,
num_experts
,
topk
,
per_act_token
,
per_out_ch
,
mkn
)
)
print
(
f
"Testing:
{
sub_label
}
"
)
(
m
,
k
,
n
)
=
mkn
dtype
=
torch
.
half
device
=
"cuda"
a
=
torch
.
randn
((
m
,
k
),
device
=
device
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
num_experts
,
2
*
n
,
k
),
device
=
device
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
num_experts
,
k
,
n
),
device
=
device
,
dtype
=
dtype
)
/
10
_
,
a_fp8_scale
=
ops
.
scaled_fp8_quant
(
a
)
w1_fp8q
=
torch
.
empty
(
(
num_experts
,
2
*
n
,
k
),
device
=
device
,
dtype
=
torch
.
float8_e4m3fn
)
w2_fp8q
=
torch
.
empty
((
num_experts
,
k
,
n
),
device
=
device
,
dtype
=
torch
.
float8_e4m3fn
)
w1_fp8scale
=
torch
.
empty
((
num_experts
,
1
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
w2_fp8scale
=
torch
.
empty
((
num_experts
,
1
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
for
expert
in
range
(
num_experts
):
w1_fp8q
[
expert
],
w1_fp8scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
w1
[
expert
])
w2_fp8q
[
expert
],
w2_fp8scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
w2
[
expert
])
w1_fp8q_notransp
=
w1_fp8q
.
clone
()
w2_fp8q_notransp
=
w2_fp8q
.
clone
()
w1_fp8q
=
w1_fp8q
.
transpose
(
1
,
2
)
w2_fp8q
=
w2_fp8q
.
transpose
(
1
,
2
)
score
=
torch
.
randn
((
m
,
num_experts
),
device
=
device
,
dtype
=
dtype
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
renormalize
=
False
)
quant_blocksize
=
16
w1_blockscale
=
torch
.
empty
(
(
num_experts
,
2
*
n
,
k
//
quant_blocksize
),
device
=
device
,
dtype
=
torch
.
float8_e4m3fn
,
)
w2_blockscale
=
torch
.
empty
(
(
num_experts
,
k
,
n
//
quant_blocksize
),
device
=
device
,
dtype
=
torch
.
float8_e4m3fn
)
# n_b_scales = 2 * n if per_out_ch else 1
# k_b_scales = k if per_out_ch else 1
w1_fp4
=
torch
.
empty
((
num_experts
,
2
*
n
,
k
//
2
),
device
=
device
,
dtype
=
torch
.
uint8
)
w2_fp4
=
torch
.
empty
((
num_experts
,
k
,
n
//
2
),
device
=
device
,
dtype
=
torch
.
uint8
)
w1_gs
=
torch
.
empty
((
num_experts
,),
device
=
device
,
dtype
=
torch
.
float32
)
w2_gs
=
torch
.
empty
((
num_experts
,),
device
=
device
,
dtype
=
torch
.
float32
)
a1_gs
=
torch
.
ones
((
num_experts
,),
device
=
device
,
dtype
=
torch
.
float32
)
a2_gs
=
torch
.
ones
((
num_experts
,),
device
=
device
,
dtype
=
torch
.
float32
)
for
expert
in
range
(
num_experts
):
w1_e
=
w1
[
expert
]
w2_e
=
w2
[
expert
]
w1_amax
=
torch
.
abs
(
w1_e
).
max
().
to
(
torch
.
float32
)
w2_amax
=
torch
.
abs
(
w2_e
).
max
().
to
(
torch
.
float32
)
w1_gs
[
expert
]
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
w1_amax
w2_gs
[
expert
]
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
w2_amax
w1_fp4
[
expert
],
w1_blockscale
[
expert
]
=
ops
.
scaled_fp4_quant
(
w1_e
,
w1_gs
[
expert
]
)
w2_fp4
[
expert
],
w2_blockscale
[
expert
]
=
ops
.
scaled_fp4_quant
(
w2_e
,
w2_gs
[
expert
]
)
def
run_triton_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
a_fp8_scale
:
torch
.
Tensor
,
num_repeats
:
int
,
):
for
_
in
range
(
num_repeats
):
fused_experts
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a_fp8_scale
,
)
def
run_cutlass_moe_fp4
(
a
:
torch
.
Tensor
,
w1_fp4
:
torch
.
Tensor
,
w2_fp4
:
torch
.
Tensor
,
w1_blockscale
:
torch
.
Tensor
,
w2_blockscale
:
torch
.
Tensor
,
w1_gs
:
torch
.
Tensor
,
w2_gs
:
torch
.
Tensor
,
a1_gs
:
torch
.
Tensor
,
a2_gs
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
device
:
torch
.
device
,
num_repeats
:
int
,
):
for
_
in
range
(
num_repeats
):
with
nvtx
.
annotate
(
"cutlass_moe_fp4"
,
color
=
"green"
):
cutlass_moe_fp4
(
a
=
a
,
a1_gscale
=
a1_gs
,
a2_gscale
=
a2_gs
,
w1_fp4
=
w1_fp4
,
w1_blockscale
=
w1_blockscale
,
w1_alphas
=
w1_gs
,
w2_fp4
=
w2_fp4
,
w2_blockscale
=
w2_blockscale
,
w2_alphas
=
w2_gs
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
m
=
m
,
n
=
n
,
k
=
k
,
e
=
num_experts
,
device
=
device
,
)
def
run_cutlass_from_graph
(
a
:
torch
.
Tensor
,
a1_gscale
:
torch
.
Tensor
,
w1_fp4
:
torch
.
Tensor
,
w1_blockscale
:
torch
.
Tensor
,
w1_alphas
:
torch
.
Tensor
,
a2_gscale
:
torch
.
Tensor
,
w2_fp4
:
torch
.
Tensor
,
w2_blockscale
:
torch
.
Tensor
,
w2_alphas
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
device
:
torch
.
device
,
):
with
set_current_vllm_config
(
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))
):
return
cutlass_moe_fp4
(
a
=
a
,
a1_gscale
=
a1_gs
,
w1_fp4
=
w1_fp4
,
w1_blockscale
=
w1_blockscale
,
w1_alphas
=
w1_alphas
,
a2_gscale
=
a2_gs
,
w2_fp4
=
w2_fp4
,
w2_blockscale
=
w2_blockscale
,
w2_alphas
=
w2_alphas
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
m
=
m
,
n
=
n
,
k
=
k
,
e
=
num_experts
,
device
=
device
,
)
def
run_triton_from_graph
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
a_fp8_scale
:
torch
.
Tensor
,
):
with
set_current_vllm_config
(
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))
):
return
fused_experts
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a_fp8_scale
,
)
def
replay_graph
(
graph
,
num_repeats
):
for
_
in
range
(
num_repeats
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
cutlass_stream
=
torch
.
cuda
.
Stream
()
cutlass_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
cutlass_graph
,
stream
=
cutlass_stream
):
run_cutlass_from_graph
(
a
=
a
,
a1_gscale
=
a1_gs
,
w1_fp4
=
w1_fp4
,
w1_blockscale
=
w1_blockscale
,
w1_alphas
=
w1_gs
,
a2_gscale
=
a2_gs
,
w2_fp4
=
w2_fp4
,
w2_blockscale
=
w2_blockscale
,
w2_alphas
=
w2_gs
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
m
=
m
,
n
=
n
,
k
=
k
,
e
=
num_experts
,
device
=
device
,
)
torch
.
cuda
.
synchronize
()
triton_stream
=
torch
.
cuda
.
Stream
()
triton_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
triton_graph
,
stream
=
triton_stream
):
run_triton_from_graph
(
a
,
w1_fp8q_notransp
,
w2_fp8q_notransp
,
topk_weights
,
topk_ids
,
w1_fp8scale
,
w2_fp8scale
,
a_fp8_scale
,
)
torch
.
cuda
.
synchronize
()
min_run_time
=
5
num_warmup
=
5
num_runs
=
25
globals
=
{
# Baseline params
"w1"
:
w1
,
"w2"
:
w2
,
"score"
:
score
,
"topk"
:
topk
,
"w1_fp8q_notransp"
:
w1_fp8q_notransp
,
"w2_fp8q_notransp"
:
w2_fp8q_notransp
,
"w1_fp8scale"
:
w1_fp8scale
,
"w2_fp8scale"
:
w2_fp8scale
,
"a_fp8_scale"
:
a_fp8_scale
,
# Cutlass params
"a"
:
a
,
"a1_gscale"
:
a1_gs
,
"w1_fp4"
:
w1_fp4
,
"w1_blockscale"
:
w1_blockscale
,
"w1_alphas"
:
w1_gs
,
"a2_gscale"
:
a2_gs
,
"w2_fp4"
:
w2_fp4
,
"w2_blockscale"
:
w2_blockscale
,
"w2_alphas"
:
w2_gs
,
"topk_weights"
:
topk_weights
,
"topk_ids"
:
topk_ids
,
"m"
:
m
,
"n"
:
n
,
"k"
:
k
,
"e"
:
num_experts
,
"device"
:
device
,
# cuda graph params
"cutlass_graph"
:
cutlass_graph
,
"triton_graph"
:
triton_graph
,
# Gen params
"num_runs"
:
num_runs
,
# Kernels
"run_triton_moe"
:
run_triton_moe
,
"run_cutlass_moe_fp4"
:
run_cutlass_moe_fp4
,
"replay_graph"
:
replay_graph
,
}
# Warmup
run_triton_moe
(
a
,
w1_fp8q_notransp
,
w2_fp8q_notransp
,
topk_weights
,
topk_ids
,
w1_fp8scale
,
w2_fp8scale
,
a_fp8_scale
,
num_warmup
,
)
results
.
append
(
benchmark
.
Timer
(
stmt
=
"run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
"triton_moe"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
# Warmup
replay_graph
(
triton_graph
,
num_warmup
)
results
.
append
(
benchmark
.
Timer
(
stmt
=
"replay_graph(triton_graph, num_runs)"
,
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
"triton_moe_cuda_graphs"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
# Warmup
run_cutlass_moe_fp4
(
a
,
w1_fp4
,
w2_fp4
,
w1_blockscale
,
w2_blockscale
,
w1_gs
,
w2_gs
,
a1_gs
,
a2_gs
,
topk_weights
,
topk_ids
,
m
,
n
,
k
,
num_experts
,
device
,
num_warmup
,
)
results
.
append
(
benchmark
.
Timer
(
stmt
=
"run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
"cutlass_moe_fp4"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
# Warmup
replay_graph
(
cutlass_graph
,
num_warmup
)
results
.
append
(
benchmark
.
Timer
(
stmt
=
"replay_graph(cutlass_graph, num_runs)"
,
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
"cutlass_moe_fp4_cuda_graphs"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
def
main
(
args
):
print
(
"Benchmarking models:"
)
for
i
,
model
in
enumerate
(
args
.
models
):
print
(
f
"[
{
i
}
]
{
model
}
"
)
results
:
list
[
benchmark
.
Measurement
]
=
[]
for
model
in
args
.
models
:
for
tp
in
args
.
tp_sizes
:
for
layer
in
WEIGHT_SHAPES_MOE
[
model
]:
num_experts
=
layer
[
0
]
topk
=
layer
[
1
]
size_k
=
layer
[
2
]
size_n
=
layer
[
3
]
//
tp
if
len
(
args
.
limit_k
)
>
0
and
size_k
not
in
args
.
limit_k
:
continue
if
len
(
args
.
limit_n
)
>
0
and
size_n
not
in
args
.
limit_n
:
continue
for
per_act_token
in
PER_ACT_TOKEN_OPTS
:
for
per_out_ch
in
PER_OUT_CH_OPTS
:
for
size_m
in
args
.
batch_sizes
:
mkn
=
(
size_m
,
size_k
,
size_n
)
bench_run
(
results
,
model
,
num_experts
,
topk
,
per_act_token
,
per_out_ch
,
mkn
,
)
compare
=
benchmark
.
Compare
(
results
)
compare
.
print
()
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
)
parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
DEFAULT_MODELS
,
choices
=
WEIGHT_SHAPES_MOE
.
keys
(),
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_TP_SIZES
)
parser
.
add_argument
(
"--batch-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_BATCH_SIZES
)
parser
.
add_argument
(
"--limit-k"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-n"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-num-groups"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-per-act-token"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-per-out-ch"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
args
=
parser
.
parse_args
()
main
(
args
)
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
0 → 100644
View file @
c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.utils.benchmark
as
benchmark
from
benchmark_shapes
import
WEIGHT_SHAPES_MOE
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_topk
,
)
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"nm-testing/Mixtral-8x7B-Instruct-v0.1"
,
"nm-testing/deepseekv2-lite"
,
"ibm-granite/granite-3.0-1b-a400m"
,
"ibm-granite/granite-3.0-3b-a800m"
,
]
DEFAULT_BATCH_SIZES
=
[
1
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_TP_SIZES
=
[
1
]
PER_ACT_TOKEN_OPTS
=
[
False
]
PER_OUT_CH_OPTS
=
[
False
]
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
bench_run
(
results
:
list
[
benchmark
.
Measurement
],
model
:
str
,
num_experts
:
int
,
topk
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
mkn
:
tuple
[
int
,
int
,
int
],
):
label
=
"Quant Matmul"
sub_label
=
(
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})"
.
format
(
model
,
num_experts
,
topk
,
per_act_token
,
per_out_ch
,
mkn
)
)
print
(
f
"Testing:
{
sub_label
}
"
)
(
m
,
k
,
n
)
=
mkn
dtype
=
torch
.
half
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
num_experts
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
num_experts
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
_
,
a_scale
=
ops
.
scaled_fp8_quant
(
a
)
w1_q
=
torch
.
empty
(
(
num_experts
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
float8_e4m3fn
)
w2_q
=
torch
.
empty
((
num_experts
,
k
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float8_e4m3fn
)
w1_scale
=
torch
.
empty
((
num_experts
,
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
empty
((
num_experts
,
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
for
expert
in
range
(
num_experts
):
w1_q
[
expert
],
w1_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
w1
[
expert
])
w2_q
[
expert
],
w2_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
w2
[
expert
])
score
=
torch
.
randn
((
m
,
num_experts
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
,
topk
,
renormalize
=
False
)
def
run_triton_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
a_scale
:
torch
.
Tensor
,
num_repeats
:
int
,
):
for
_
in
range
(
num_repeats
):
fused_experts
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a_scale
,
)
def
run_cutlass_moe
(
a
:
torch
.
Tensor
,
a_scale
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
per_act_token
:
bool
,
num_repeats
:
int
,
):
for
_
in
range
(
num_repeats
):
cutlass_moe_fp8
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
w1_scale
,
w2_scale
,
per_act_token
,
a1_scale
=
None
,
)
def
run_cutlass_from_graph
(
a
:
torch
.
Tensor
,
a_scale
:
torch
.
Tensor
,
w1_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
):
with
set_current_vllm_config
(
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))
):
return
cutlass_moe_fp8
(
a
,
w1_q
,
w2_q
,
topk_weights
,
topk_ids
,
w1_scale
,
w2_scale
,
per_act_token
,
a1_scale
=
None
,
)
def
run_triton_from_graph
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
a_scale
:
torch
.
Tensor
,
):
with
set_current_vllm_config
(
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))
):
return
fused_experts
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a_scale
,
)
def
replay_graph
(
graph
,
num_repeats
):
for
_
in
range
(
num_repeats
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
cutlass_stream
=
torch
.
cuda
.
Stream
()
cutlass_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
cutlass_graph
,
stream
=
cutlass_stream
):
run_cutlass_from_graph
(
a
,
a_scale
,
w1_q
,
w2_q
,
w1_scale
,
w2_scale
,
topk_weights
,
topk_ids
,
)
torch
.
cuda
.
synchronize
()
triton_stream
=
torch
.
cuda
.
Stream
()
triton_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
triton_graph
,
stream
=
triton_stream
):
run_triton_from_graph
(
a
,
w1_q
,
w2_q
,
topk_weights
,
topk_ids
,
w1_scale
,
w2_scale
,
a_scale
,
)
torch
.
cuda
.
synchronize
()
    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "per_act_token": per_act_token,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_warmup
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe(
        a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids,
        per_act_token, num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in DEFAULT_BATCH_SIZES:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results, model, num_experts, topk,
                                per_act_token, per_out_ch, mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument("--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
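main() above derives each benchmark shape from a WEIGHT_SHAPES_MOE entry and divides the N dimension by the tensor-parallel size. A small illustration of that bookkeeping; the entry layout (num_experts, topk, size_k, size_n) follows the loop above, while the concrete numbers are hypothetical:

    # Hypothetical entry: 8 experts, top-2 routing, K=4096, N=14336, run at TP=2.
    layer = (8, 2, 4096, 14336)
    tp = 2

    num_experts, topk, size_k, size_n = layer[0], layer[1], layer[2], layer[3] // tp
    batch_sizes = [1, 16, 64]

    # One (M, K, N) problem per batch size, matching bench_run's mkn argument.
    problems = [(m, size_k, size_n) for m in batch_sizes]
    print(num_experts, topk, problems)
    # 8 2 [(1, 4096, 7168), (16, 4096, 7168), (64, 4096, 7168)]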
benchmarks/kernels/benchmark_layernorm.py
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import time

 import torch

 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
-                        seed_everything)
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


 @torch.inference_mode()
-def main(num_tokens: int,
-         hidden_size: int,
-         add_residual: bool,
-         dtype: torch.dtype,
-         seed: int = 0,
-         do_profile: bool = False,
-         num_warmup_iters: int = 5,
-         num_iters: int = 100) -> None:
-    seed_everything(seed)
+def main(
+    num_tokens: int,
+    hidden_size: int,
+    add_residual: bool,
+    dtype: torch.dtype,
+    seed: int = 0,
+    do_profile: bool = False,
+    num_warmup_iters: int = 5,
+    num_iters: int = 100,
+) -> None:
+    current_platform.seed_everything(seed)
     torch.set_default_device("cuda")

     layer = RMSNorm(hidden_size).to(dtype=dtype)
...
@@ -38,7 +43,7 @@ def main(num_tokens: int,
         end_time = time.perf_counter()
         if profile:
-            torch.cuda.cudart().cudaProfilerStart()
+            torch.cuda.cudart().cudaProfilerStop()
         return (end_time - start_time) / num_iters

     # Warmup.
...
@@ -54,33 +59,35 @@ def main(num_tokens: int,
     print(f"Kernel running time: {latency * 1000000:.3f} us")


-if __name__ == '__main__':
-    parser = FlexibleArgumentParser(
-        description="Benchmark the layernorm kernel.")
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
     parser.add_argument("--num-tokens", type=int, default=4096)
     parser.add_argument("--hidden-size", type=int, default=8192)
     parser.add_argument("--add-residual", action="store_true")
-    parser.add_argument("--dtype",
-                        type=str,
-                        choices=["half", "bfloat16", "float"],
-                        default="half")
+    parser.add_argument(
+        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+    )
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--profile", action="store_true")
     parser.add_argument("--num-warmup-iters", type=int, default=5)
-    parser.add_argument("--num-iters",
-                        type=int,
-                        default=100,
-                        help="Number of benchmark iterations. "
-                        "If --profile is set, this number is ignored")
+    parser.add_argument(
+        "--num-iters",
+        type=int,
+        default=100,
+        help="Number of benchmark iterations. "
+        "If --profile is set, this number is ignored",
+    )

     args = parser.parse_args()
     print(args)

-    main(num_tokens=args.num_tokens,
-         hidden_size=args.hidden_size,
-         add_residual=args.add_residual,
-         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
-         seed=args.seed,
-         do_profile=args.profile,
-         num_warmup_iters=args.num_warmup_iters,
-         num_iters=args.num_iters)
+    main(
+        num_tokens=args.num_tokens,
+        hidden_size=args.hidden_size,
+        add_residual=args.add_residual,
+        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
+        seed=args.seed,
+        do_profile=args.profile,
+        num_warmup_iters=args.num_warmup_iters,
+        num_iters=args.num_iters,
+    )
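The hunks above collapse most of the timing loop ("..."); the snippet below is not the file's elided body, only a hedged sketch of the same warmup-then-time pattern, assuming the RMSNorm layer and argument names shown above:

    import time
    import torch
    from vllm.model_executor.layers.layernorm import RMSNorm

    @torch.inference_mode()
    def time_rmsnorm(num_tokens: int = 4096, hidden_size: int = 8192,
                     num_warmup_iters: int = 5, num_iters: int = 100) -> float:
        torch.set_default_device("cuda")
        layer = RMSNorm(hidden_size).to(dtype=torch.half)
        x = torch.randn(num_tokens, hidden_size, dtype=torch.half)

        for _ in range(num_warmup_iters):   # warmup excludes one-time costs
            layer(x)
        torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(num_iters):
            layer(x)
        torch.cuda.synchronize()             # wait for all queued kernels
        return (time.perf_counter() - start) / num_iters

    print(f"Kernel running time: {time_rmsnorm() * 1e6:.3f} us")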
benchmarks/kernels/benchmark_lora.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import copy
import json
import pickle
import time
from dataclasses import dataclass
from enum import Enum, auto
from itertools import product
from pathlib import Path
from typing import Any, Callable, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES

from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
    1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024,
    2048, 3072, 4096, 5120, 6144, 7168, 8192,
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]


# Utilities
def dtype_to_str(dtype: torch.dtype):
    if dtype == torch.float16:
        return "f16"
    if dtype == torch.bfloat16:
        return "bf16"
    if dtype == torch.float32:
        return "f32"
    raise ValueError(f"Unsupported dtype {dtype}")


def make_rand_lora_weight_tensor(
    k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda"
) -> torch.Tensor:
    # LoRA weights column major
    return torch.rand((num_loras, n, k), dtype=dtype).to(device)


def make_rand_tensors(
    a_shape: tuple[int],
    b_shape: tuple[int],
    c_shape: tuple[int],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
    """
    Make LoRA input/output matrices.
    """
    A = torch.rand(a_shape, dtype=a_dtype).to(device)

    # LoRA weights column major
    Bs = [torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices)]

    C = torch.zeros(c_shape, dtype=c_dtype).to(device)
    return A, Bs, C


def make_prompt_lora_mapping(
    num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str
) -> torch.Tensor:
    """
    All prompts are mapped to a LoRA ID in range [0, num_active_loras).
    where 0 refers to first lora, 1 refers to second lora and so on.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        return torch.randint(0, num_active_loras, (num_prompts,), dtype=torch.long)

    # Divide LoRAs equally and in order.
    part_size = num_prompts // num_active_loras
    part_size = max(part_size, 1)

    lora_id = 0
    prompt_lora_mapping = []
    while len(prompt_lora_mapping) < num_prompts:
        prompt_lora_mapping.extend([lora_id] * part_size)
        lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id
    return torch.tensor(
        prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device
    )


def make_token_lora_mapping(
    num_tokens: int,
    num_prompts: int,
    prompt_lora_mapping: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    device: str,
):
    """
    Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor
    """
    assert prompt_lora_mapping.shape[0] == num_prompts

    # token to lora index mapping
    token_lora_mapping = [0] * num_tokens
    current_offset = 0
    for b_id in range(num_prompts):
        lora_index = prompt_lora_mapping[b_id].item()
        s = current_offset
        e = s + seq_len_tensor[b_id].item()
        token_lora_mapping[s:e] = [lora_index] * (e - s)
        current_offset += seq_len_tensor[b_id].item()

    return torch.tensor(token_lora_mapping, dtype=torch.long, device=device)


def ref_group_gemm(
    ref_out: torch.Tensor,
    input: torch.Tensor,
    lora_weights: list[torch.Tensor],
    seq_lens_cpu: torch.Tensor,
    prompt_lora_mapping_cpu: torch.Tensor,
    scaling: float,
    add_inputs: Optional[bool],
):
    """
    Torch group gemm reference implementation to test correctness of
    benchmarking operations.
    """
    batches = seq_lens_cpu.size(0)
    out_list = []
    current_offset = 0
    for lora_index, b_length in zip(range(batches), seq_lens_cpu):
        x = input[current_offset : b_length + current_offset, :]
        current_offset += b_length
        w = lora_weights[prompt_lora_mapping_cpu[lora_index]]
        result = torch.nn.functional.linear(x, w)
        result *= scaling
        out_list.append(result)

    cat_result = torch.cat(out_list, dim=0)

    if add_inputs:
        ref_out += cat_result
    else:
        ref_out.copy_(cat_result)
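To make the two mapping helpers above concrete, here is a small worked example with illustrative values: four prompts, two active LoRAs, sorted by LoRA id, each prompt with sequence length 3.

    import torch

    num_prompts, num_active_loras, seq_len = 4, 2, 3

    # make_prompt_lora_mapping with sort_by_lora_id=True splits prompts evenly:
    prompt_lora_mapping = torch.tensor([0, 0, 1, 1])

    seq_len_tensor = torch.full((num_prompts,), seq_len)

    # make_token_lora_mapping then repeats each prompt's LoRA id once per token:
    token_lora_mapping = torch.repeat_interleave(prompt_lora_mapping, seq_len_tensor)
    print(token_lora_mapping.tolist())
    # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]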
class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """

    LORA_SHRINK = auto()
    LORA_EXPAND = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        if s.lower() == "lora_shrink":
            return OpType.LORA_SHRINK
        if s.lower() == "lora_expand":
            return OpType.LORA_EXPAND
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        return self in [OpType.LORA_SHRINK]

    def is_expand_fn(self) -> bool:
        return self in [OpType.LORA_EXPAND]

    def num_slices(self) -> list[int]:
        return [1, 2, 3]

    def mkn(
        self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
    ) -> tuple[int, int, int]:
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            m = num_tokens
            k = hidden_size
            n = lora_rank
        else:
            assert self.is_expand_fn()
            m = num_tokens
            k = lora_rank
            n = hidden_size
        return m, k, n

    def matmul_dtypes(
        self, op_dtype: torch.dtype
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C
        """
        if self.is_shrink_fn():
            return op_dtype, op_dtype, torch.float32
        else:
            assert self.is_expand_fn()
            return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
        self,
        batch_size: int,
        seq_length: int,
        hidden_size: int,
        lora_rank: int,
        num_loras: int,
        num_slices: int,
    ) -> tuple[tuple[int], tuple[int], tuple[int]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major
        if self in [OpType.LORA_SHRINK]:
            # LoRA shrink kernels support num_slices inherently in the kernel.
            return ((m, k), b_shape, (num_slices, m, n))
        if self in [OpType.LORA_EXPAND]:
            # LoRA expand kernels support num_slices inherently in the kernel
            return ((num_slices, m, k), b_shape, (m, n * num_slices))

        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        if self == OpType.LORA_SHRINK:
            return lora_shrink
        if self == OpType.LORA_EXPAND:
            return lora_expand

        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(
        self,
        output: torch.Tensor,
        input: torch.Tensor,
        lora_weights: list[torch.Tensor],
        **kwargs,
    ) -> Callable:
        """Each benchmark operation expects the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self in [OpType.LORA_SHRINK]:
            for slice_idx in range(num_slices):
                ref_group_gemm(
                    ref_out=output[slice_idx, :],
                    input=input,
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        elif self in [OpType.LORA_EXPAND]:
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset : slice_offset + hidden_size],
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        else:
            raise ValueError(f"Unrecognized optype {self}")


@dataclass
class BenchmarkContext:
    """
    LoRA benchmark context
    """

    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
    seq_length: Optional[int] = None
    num_slices: Optional[int] = None  # num_slices for slice based ops

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        ctx = copy.copy(self)
        ctx.seq_length = seq_length
        return ctx

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        ctx = copy.copy(self)
        ctx.num_slices = num_slices
        return ctx

    def bench_label(self) -> str:
        return f"lora-{self.dtype}"

    def bench_sublabel(self, op_type: OpType) -> str:
        m, k, n = op_type.mkn(
            self.batch_size, self.seq_length, self.hidden_size, self.lora_rank
        )
        desc = {
            "bs": self.batch_size,
            "sl": self.seq_length,
            "m": m,
            "k": k,
            "n": n,
            "num_loras": self.num_loras,
            "sort_by_lora": self.sort_by_lora_id,
            "num_slices": self.num_slices,
        }
        return json.dumps(desc)
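A concrete reading of mkn() and matmul_shapes() above, for an illustrative configuration (batch 8, sequence length 1, hidden size 4096, rank 16, 4 LoRAs, 3 slices); the numbers are examples, not benchmark defaults:

    batch_size, seq_length, hidden_size, lora_rank = 8, 1, 4096, 16
    num_loras, num_slices = 4, 3
    num_tokens = batch_size * seq_length  # 8

    # LORA_SHRINK: (m, k, n) = (tokens, hidden, rank)
    shrink_a = (num_tokens, hidden_size)               # (8, 4096)
    shrink_b = (num_loras, lora_rank, hidden_size)     # (4, 16, 4096), col-major
    shrink_c = (num_slices, num_tokens, lora_rank)     # (3, 8, 16)

    # LORA_EXPAND: (m, k, n) = (tokens, rank, hidden)
    expand_a = (num_slices, num_tokens, lora_rank)     # (3, 8, 16)
    expand_b = (num_loras, hidden_size, lora_rank)     # (4, 4096, 16)
    expand_c = (num_tokens, hidden_size * num_slices)  # (8, 12288)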
@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks
    """

    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: list[torch.Tensor]
    output: torch.Tensor
    # LoRA kernel metadata
    lora_kernel_meta: LoRAKernelMeta
    # Metadata tensors used in testing correctness
    seq_lens: torch.Tensor
    prompt_lora_mapping: torch.Tensor

    def io_types(self) -> str:
        return (
            f"{dtype_to_str(self.input.dtype)}x"
            f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
            f"{dtype_to_str(self.output.dtype)}"
        )

    @staticmethod
    def make(
        ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
    ) -> "BenchmarkTensors":
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size, ctx.seq_length, ctx.hidden_size,
            ctx.lora_rank, ctx.num_loras, ctx.num_slices,
        )
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        input_tensor, lora_weights, output_tensor = make_rand_tensors(
            a_shape, b_shape, c_shape, a_type, b_type, c_type, num_slices=ctx.num_slices
        )

        # Make metadata tensors.
        # Keep the metadata tensors in the CPU for further processing if needed.
        # The tensors get moved to the GPU before benchmarking.
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Make metadata tensors involved in correctness testing.
        # Prepare seq lens tensor
        seq_len_tensor = torch.randint(
            ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,)
        )
        assert total_tokens == seq_len_tensor.sum()
        # Prepare prompt lora indices tensor
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
        )

        # Make LoRAKernelMeta
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
            seq_len_tensor, "cpu",
        )
        lora_kernel_meta = LoRAKernelMeta.make(
            max_loras=ctx.num_loras,
            max_num_tokens=token_lora_indices_tensor.size(0),
            device="cpu",
        )
        lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)

        return BenchmarkTensors(
            input_tensor, lora_weights, output_tensor,
            lora_kernel_meta, seq_len_tensor, prompt_lora_indices_tensor,
        )

    def sanity_check(self) -> None:
        """
        Fails asserts when non-conformality is detected.
        """
        num_tokens = self.input.shape[-2]
        # check metadata tensors
        assert torch.sum(self.seq_lens) == num_tokens
        num_seqs = self.seq_lens.shape[0]
        # assert self.seq_start_loc.shape[0] == num_seqs
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens

    def to_device(self, device: str):
        """
        Transfer tensors to device if the tensors aren't already on the device
        """

        def to_device(tensor: torch.Tensor):
            if tensor.device != device:
                tensor = tensor.to(device=device)
            return tensor

        self.input = to_device(self.input)
        self.output = to_device(self.output)
        self.seq_lens = to_device(self.seq_lens)
        self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
        for i in range(len(self.lora_weights_lst)):
            self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

        # LoRA meta
        for field_name in LoRAKernelMeta.__dataclass_fields__:
            field = getattr(self.lora_kernel_meta, field_name)
            assert isinstance(field, torch.Tensor)
            setattr(self.lora_kernel_meta, field_name, to_device(field))

    def metadata(self) -> tuple[int, int, int]:
        """
        Return num_seqs, num_tokens and max_seq_len
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def as_lora_shrink_kwargs(self) -> dict[str, Any]:
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape, self.lora_weights_lst[0].shape, self.output.shape,
        )
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            "inputs": self.input,
            "lora_a_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "scaling": 1.0,
        }

    def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape, self.lora_weights_lst[0].shape, self.output.shape,
        )
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            "inputs": self.input,
            "lora_b_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "offset_start": 0,
            "add_inputs": add_inputs,
        }

    def bench_fn_kwargs(
        self, op_type: OpType, add_inputs: Optional[bool] = None
    ) -> dict[str, Any]:
        if op_type.is_shrink_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.LORA_SHRINK:
            return self.as_lora_shrink_kwargs()
        if op_type == OpType.LORA_EXPAND:
            return self.as_lora_expand_kwargs(add_inputs)
        raise ValueError(f"Unrecognized optype {self}")

    def test_correctness(
        self, op_type: OpType, expand_fn_add_inputs: Optional[bool]
    ) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.
        """
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")
        ref_output = self.output.clone()

        self.output.zero_()
        op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs,
        )

        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)


def bench_optype(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: Optional[int] = None,
    expand_fn_add_inputs: Optional[bool] = None,
    test_correctness: bool = False,
) -> TMeasurement:
    assert arg_pool_size >= 1
    if op_type.is_shrink_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    # BenchmarkContext -> BenchmarkTensors
    bench_tensors: list[BenchmarkTensors] = [
        BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
    ]
    for bt in bench_tensors:
        bt.sanity_check()

    # Test correctness of our implementation.
    if test_correctness:
        assert all(
            [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors]
        )

    # BenchmarkTensors -> dict (kwargs)
    kwargs_list = [
        bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        for bt in bench_tensors
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup
    for kwargs in kwargs_list:
        op_type.bench_fn()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    describe_args = (
        f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
    )
    description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"

    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    timer = None
    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        op_type.bench_fn(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer


def bench_torch_mm(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: Optional[int] = None,
) -> TMeasurement:
    """
    Benchmark basic torch.mm as a roofline.

    When all the input tokens have the same LoRA ID, the LoRA kernels are just
    a matmul. This torch.mm benchmark serves as a roofline for that case.

    input op_type is used in determining the m, k, n dimensions for the matmul.
    """

    batch_size, hidden_size, lora_rank, seq_length, dtype = (
        ctx.batch_size, ctx.hidden_size, ctx.lora_rank, ctx.seq_length, ctx.dtype,
    )

    m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank)
    # For a fairer comparison.
    n = n * ctx.num_slices

    # Get matmul input and output tensors for A x B = C
    As, Bs, Cs = [], [], []
    for _ in range(arg_pool_size):
        As.append(torch.rand((m, k), dtype=dtype).to("cuda"))
        Bs.append(torch.rand((n, k), dtype=dtype).to("cuda").t())
        Cs.append(torch.rand((m, n), dtype=dtype).to("cuda"))

    # Make torch.mm kwargs
    mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)}

    description = (
        f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
        f"x{dtype_to_str(dtype)}"
        f"=>{dtype_to_str(dtype)})"
    )
    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        torch.mm,
        **mm_kwargs,
    ) as bench:
        return bench.run()
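The roofline idea in bench_torch_mm's docstring can be restated in a few lines: with a single active LoRA, lora_shrink degenerates to one plain matmul. The sketch below only reproduces the shapes being compared; it is not the benchmark harness itself, and the sizes are illustrative:

    import torch

    num_tokens, hidden_size, lora_rank, num_slices = 1024, 4096, 16, 3
    # Shrink dimensions, with n widened by num_slices "for a fairer comparison".
    m, k, n = num_tokens, hidden_size, lora_rank * num_slices

    a = torch.rand(m, k, dtype=torch.float16, device="cuda")
    b = torch.rand(n, k, dtype=torch.float16, device="cuda").t()  # column-major, as above
    c = torch.empty(m, n, dtype=torch.float16, device="cuda")

    torch.mm(a, b, out=c)  # the single-LoRA best case the kernels are measured against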
# runner
def use_cuda_graph_recommendation() -> str:
    return """
            Triton kernels have a significant launch overhead with
            launched directly via python. This overhead is more noticeable
            for small the problem sizes. For these cases, it is recommended
            to use the script with `--cuda-graph-nops N` to benchmark N
            consecutive invocations of the benchmarking operations from
            inside a CUDA Graph. Note that the returned measurement is for N
            invocations of the operation.
            """


def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None):
    compare = TBenchmark.Compare(timers)
    compare.print()

    if args and args.cuda_graph_nops:
        print(
            f"Note : The timings reported above is for {args.cuda_graph_nops} "
            "consecutive invocations of the benchmarking functions. "
            f"Please divide by {args.cuda_graph_nops} for single invocation "
            "timings."
        )

    print(
        "Note on Comparison with torch.mm : The torch.mm numbers are "
        "benchmark numbers of a simple matmul emulating the single lora "
        "case. It is provided as a roofline for comparing our LoRA Kernel "
        "implementations. It is expected that the LoRA kernels will be "
        "slower than torch.mm in cases where num_loras is big. But for "
        "small num_loras the goal should be to match the torch.mm numbers."
    )


def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: list[OpType] = args.op_types

            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(num_slices)
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(_ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops)
                    )

                    # Benchmark bench_op
                    expand_fn_add_inputs = (
                        [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    )
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(
                                _ctx,
                                args.arg_pool_size,
                                bench_op,
                                args.cuda_graph_nops,
                                add_input_arg,
                                args.test_correctness,
                            )
                        )

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        if not od.exists():
            od.mkdir()

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)


def as_benchmark_contexts(
    hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]:
    ctxs: list[BenchmarkContext] = []
    for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product(  # noqa
        args.batch_sizes,
        list(hidden_sizes),
        lora_ranks,
        args.num_loras,
        args.sort_by_lora_id,
    ):
        ctxs.append(
            BenchmarkContext(
                batch_size=batch_size,
                hidden_size=hidden_size,
                lora_rank=lora_rank,
                num_loras=num_loras,
                num_active_loras=args.num_active_loras
                if args.num_active_loras
                else num_loras,
                # To be filled based on the OpType to benchmark
                seq_length=None,
                sort_by_lora_id=sort_by_lora_id,
                dtype=args.dtype,
                # To be filled based on the OpType to benchmark
                num_slices=None,
            )
        )

    return ctxs


def run_list_bench(args: argparse.Namespace):
    print(args)

    print(
        "List bench :\n"
        f"  Hidden Sizes {args.hidden_sizes}"
        f"  LoRA Ranks {args.lora_ranks}"
    )

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)


def run_range_bench(args: argparse.Namespace):
    print(args)

    hidden_sizes = list(
        range(
            args.hidden_sizes_start,
            args.hidden_sizes_end + 1,
            args.hidden_sizes_increment,
        )
    )
    lora_ranks = list(
        range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment)
    )

    print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args
    )

    run(args, bench_contexts)


def run_model_bench(args: argparse.Namespace):
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            hidden_sizes.add(KN[1])
        return hidden_sizes

    # Get all hidden sizes
    hidden_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size))

    print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)


if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        return s.lower() in ["true", "1"]

    def add_common_command_args(p: argparse.ArgumentParser):
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']",
        )

        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            help="Run profiles with a pool of input/output/meta tensors instead"
            "of simply reusing the same tensors for all runs. A bigger arg-pool"
            "mitigates hardware caching effects during benchmarking.",
        )

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=(
                "when set profiling is done using cudagraph, "
                "with the given number of operations in a graph."
                "Note that the measurement returned is the time "
                "taken for N consecutive executions of the benchmarking "
                "functions, where N is the value of this argument."
            ),
        )
        p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS)
        p.add_argument(
            "--num-active-loras",
            type=int,
            default=None,
            help="Active LoRAs. When None, all LoRAs are active",
        )
        p.add_argument(
            "--sort-by-lora-id",
            nargs="+",
            type=get_bool,
            default=DEFAULT_SORT_BY_LORA_IDS,
        )
        p.add_argument("--op-types", nargs="+", type=OpType.from_str, default=list(OpType))
        p.add_argument("--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS)
        p.add_argument("--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES)
        p.add_argument(
            "--expand-fn-add-inputs",
            nargs="+",
            type=get_bool,
            default=DEFAULT_EXPAND_FN_ADD_INPUTS,
        )
        p.add_argument(
            "-o",
            "--output-directory",
            type=str,
            help=(
                "Output directory to store a the list of benchmarking"
                "TMeasurement objects as a pickle file"
            ),
        )

        p.add_argument(
            "--test-correctness",
            action="store_true",
            help=(
                "When enabled, the benchmarking functions are tested"
                "for correctness before the actual benchmarking"
            ),
        )

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
    {use_cuda_graph_recommendation()}

    list_bench example:
        python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    model_bench example:
        python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    range_bench example:
        python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument("--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES)
    list_parser.add_argument("--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS)
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment", type=int, required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment", type=int, required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS)
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
benchmarks/kernels/benchmark_machete.py
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import argparse
 import copy
 import itertools
 import math
 import os
 import pickle as pkl
 import time
+from collections.abc import Iterable
+from dataclasses import dataclass
 from itertools import product
-from typing import Callable, Iterable, List, Optional, Tuple
+from typing import Callable, Optional

 import pandas as pd
 import torch
...
@@ -15,11 +21,18 @@ from weight_shapes import WEIGHT_SHAPES
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
+    GPTQ_MARLIN_MAX_PARALLEL,
+    GPTQ_MARLIN_MIN_THREAD_N,
+    marlin_permute_scales,
+    marlin_zero_points,
+)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    MarlinWorkspace)
+    MarlinWorkspace,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, pack_rows, quantize_weights)
+    pack_rows,
+    quantize_weights,
+)
 from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils import FlexibleArgumentParser
...
@@ -27,149 +40,390 @@ DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
 DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
 DEFAULT_TP_SIZES = [1]

 NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)

 if NVTX_PROFILE:
     import nvtx


 def terse_type_name(dt):
     return {
         torch.bfloat16: "bf16",
         torch.float16: "fp16",
         torch.int8: "int8",
         torch.float8_e4m3fn: "fp8",
         torch.float: "float",
         torch.int: "int",
     }[dt]


+@dataclass
+class BenchmarkTensors:
+    w_ref: torch.Tensor
+    a: torch.Tensor
+
+    w_q: torch.Tensor
+    group_size: Optional[int]
+    wtype: ScalarType
+    w_g_s: torch.Tensor
+    w_g_zp: Optional[torch.Tensor]
+    w_ch_s: Optional[torch.Tensor]
+    w_tok_s: Optional[torch.Tensor]
+
+
+@dataclass
+class TypeConfig:
+    act_type: torch.dtype
+    weight_type: ScalarType
+    output_type: Optional[torch.dtype]
+    group_scale_type: Optional[torch.dtype]
+    group_zero_type: Optional[torch.dtype]
+    channel_scale_type: Optional[torch.dtype]
+    token_scale_type: Optional[torch.dtype]
+
+
+def rand_data(shape, dtype=torch.float16, scale=1):
+    if dtype.is_floating_point:
+        return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
+    else:
+        return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
+
+
+def quantize_and_pack(
+    atype: torch.dtype,
+    w: torch.Tensor,
+    wtype: ScalarType,
+    stype: Optional[torch.dtype],
+    group_size: Optional[int],
+    zero_points: bool = False,
+):
+    assert wtype.is_integer(), "TODO: support floating point weights"
+
+    w_ref, w_q, w_s, w_zp = quantize_weights(
+        w,
+        wtype,
+        group_size=group_size,
+        zero_points=zero_points,
+        # to match how the kernel applies zps
+        ref_zero_points_after_scales=True,
+    )

-def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
-    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
-    w_q = w_q.t().contiguous().t()  # make col major
-    return ops.machete_prepack_B(w_q, wtype)
+    return w_ref, w_q, w_s, w_zp


-def make_bench_tensors(
-    atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, k: int
-) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
-                                    torch.tensor]]]:
-    assert wtype.is_integer(), "TODO: support floating point weights"
+def create_bench_tensors(
+    shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
+) -> list[BenchmarkTensors]:
+    m, n, k = shape

     # we want to make sure that weights don't fit into L2 cache between runs so
     # we construct enough weights to exceed L2 cache, which is 50mb on a H100
     # so we target total weight size > 2*50mb
-    num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
+    num_weights = math.ceil(
+        2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
+    )

-    a = torch.randn((m, k), device="cuda", dtype=atype) * 5
-    weights = [
-        torch.randn((k, n), device="cuda", dtype=atype) for _ in range(num_weights)
-    ]
-    quanitized_weights = [quantize_weights(w, wtype, group_size) for w in weights]
+    a = rand_data((m, k), types.act_type, scale=5)

+    benchmark_tensors: list[BenchmarkTensors] = []
+    for _ in range(num_weights):
+        w = rand_data((k, n), types.act_type, scale=5)
+
+        if types.group_scale_type is not None:
+            w = w.to(types.group_scale_type)
+        if w.dtype.itemsize == 1:
+            w = w.to(torch.float16)
+
+        w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
+            a.dtype,
+            w,
+            types.weight_type,
+            types.group_scale_type,
+            group_size,
+            types.group_zero_type is not None,
+        )
+
+        if not a.dtype.is_floating_point:
+            aiinfo = torch.iinfo(a.dtype)
+            w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
+
+        w_ref = w_ref.to(torch.float32)
+
+        w_ch_s = (
+            None
+            if types.channel_scale_type is None
+            else rand_data((n,), types.channel_scale_type)
+        )
+        w_tok_s = (
+            None
+            if types.token_scale_type is None
+            else rand_data((m,), types.token_scale_type)
+        )
+
+        benchmark_tensors.append(
+            BenchmarkTensors(
+                w_ref=w_ref,
+                a=a,
+                w_q=w_q_packed,
+                wtype=types.weight_type,
+                w_g_s=w_s,
+                w_g_zp=w_zp,
+                group_size=group_size,
+                w_ch_s=w_ch_s,
+                w_tok_s=w_tok_s,
+            )
+        )
+
+    return benchmark_tensors
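The comment above sizes the weight pool so consecutive iterations cannot hit in L2. Worked through for one illustrative shape (k = n = 4096, 4-bit weights, 50 MiB L2 as stated in the comment):

    import math

    k, n, weight_size_bits = 4096, 4096, 4
    l2_bytes = 50 * 1024**2                      # 50 MiB, per the comment (H100)

    bits_per_weight_matrix = k * n * weight_size_bits
    num_weights = math.ceil(2 * l2_bytes * 8 / bits_per_weight_matrix)
    print(num_weights)  # 13 -> cycling through 13 weight matrices exceeds 2x L2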
def
torch_matmul_f16_create_bench_fn
(
bt
:
BenchmarkTensors
)
->
Callable
:
a
=
bt
.
a
w
=
bt
.
w_ref
.
to
(
bt
.
a
.
dtype
)
# use float reference tensor
if
a
.
dtype
not
in
[
torch
.
float16
,
torch
.
bfloat16
]:
a
=
a
.
to
(
torch
.
float16
)
w
=
w
.
to
(
torch
.
float16
)
return
lambda
:
torch
.
matmul
(
a
,
w
)
def
cutlass_scaled_mm_create_bench_fn
(
bt
:
BenchmarkTensors
)
->
Callable
:
if
bt
.
w_ch_s
is
not
None
and
bt
.
w_tok_s
is
not
None
:
scale_a
=
bt
.
w_tok_s
.
to
(
torch
.
float32
)
scale_b
=
bt
.
w_ch_s
.
to
(
torch
.
float32
)
else
:
scale_a
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
bt
.
a
.
device
)
scale_b
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
bt
.
a
.
device
)
w_col_major
=
bt
.
w_ref
.
to
(
bt
.
a
.
dtype
).
t
().
contiguous
().
t
()
return
lambda
:
ops
.
cutlass_scaled_mm
(
bt
.
a
,
w_col_major
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
float16
)
return
a
,
quanitized_weights
def
marlin_create_bench_fn
(
bt
:
BenchmarkTensors
)
->
Callable
:
device
=
bt
.
a
.
device
# impl
workspace
=
MarlinWorkspace
(
bt
.
w_ref
.
shape
[
1
],
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
if
bt
.
w_g_zp
is
None
:
w_zp
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
)
else
:
w_zp
=
marlin_zero_points
(
bt
.
w_g_zp
,
bt
.
w_ref
.
shape
[
0
],
bt
.
w_ref
.
shape
[
1
],
bt
.
wtype
.
size_bits
)
if
bt
.
group_size
is
None
:
w_s
=
torch
.
tensor
([],
device
=
"cuda"
,
dtype
=
torch
.
half
)
else
:
w_s
=
marlin_permute_scales
(
bt
.
w_g_s
,
bt
.
w_ref
.
shape
[
0
],
bt
.
w_ref
.
shape
[
1
],
bt
.
group_size
)
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
)
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
)
w_q
=
ops
.
gptq_marlin_repack
(
bt
.
w_q
,
sort_indices
,
bt
.
w_ref
.
shape
[
0
],
bt
.
w_ref
.
shape
[
1
],
bt
.
wtype
.
size_bits
)
if
bt
.
a
.
dtype
.
is_floating_point
:
assert
bt
.
w_ch_s
is
None
assert
bt
.
w_tok_s
is
None
assert
bt
.
group_size
is
not
None
fn
=
lambda
:
ops
.
gptq_marlin_gemm
(
a
=
bt
.
a
,
c
=
None
,
b_q_weight
=
w_q
,
b_scales
=
w_s
,
global_scale
=
None
,
b_zeros
=
w_zp
,
g_idx
=
g_idx
,
perm
=
sort_indices
,
workspace
=
workspace
.
scratch
,
b_q_type
=
bt
.
wtype
,
size_m
=
bt
.
a
.
shape
[
0
],
size_n
=
bt
.
w_ref
.
shape
[
1
],
size_k
=
bt
.
w_ref
.
shape
[
0
],
is_k_full
=
True
,
is_zp_float
=
False
,
)
else
:
assert
bt
.
a
.
dtype
==
torch
.
int8
assert
bt
.
wtype
==
scalar_types
.
uint4b8
if
bt
.
w_ch_s
is
not
None
:
s_ch
=
bt
.
w_ch_s
.
to
(
torch
.
float32
)
else
:
s_ch
=
torch
.
ones
(
bt
.
w_ref
.
shape
[
1
],
dtype
=
torch
.
float32
,
device
=
device
)
if
bt
.
w_tok_s
is
not
None
:
s_tok
=
bt
.
w_tok_s
.
to
(
torch
.
float32
)
else
:
s_tok
=
torch
.
ones
(
bt
.
a
.
shape
[
0
],
dtype
=
torch
.
float32
,
device
=
device
)
fn
=
lambda
:
ops
.
marlin_qqq_gemm
(
a
=
bt
.
a
,
b_q_weight
=
w_q
,
s_group
=
w_s
,
s_tok
=
s_tok
,
s_ch
=
s_ch
,
workspace
=
workspace
.
scratch
,
size_m
=
bt
.
a
.
shape
[
0
],
size_n
=
bt
.
w_ref
.
shape
[
1
],
size_k
=
bt
.
w_ref
.
shape
[
0
],
)
return
fn
def
machete_create_bench_fn
(
bt
:
BenchmarkTensors
,
out_type
=
torch
.
dtype
,
schedule
=
None
)
->
Callable
:
w_q
=
bt
.
w_q
.
t
().
contiguous
().
t
()
# make col major
w_q
=
ops
.
machete_prepack_B
(
w_q
,
bt
.
a
.
dtype
,
bt
.
wtype
,
None
if
bt
.
w_g_s
is
None
else
bt
.
w_g_s
.
dtype
)
w_g_zp
=
bt
.
w_g_zp
if
w_g_zp
is
not
None
:
w_g_zp
=
-
1
*
bt
.
w_g_s
*
(
w_g_zp
.
to
(
bt
.
w_g_s
.
dtype
))
return
lambda
:
ops
.
machete_mm
(
a
=
bt
.
a
,
b_q
=
w_q
,
b_type
=
bt
.
wtype
,
b_group_scales
=
bt
.
w_g_s
,
b_group_zeros
=
w_g_zp
,
b_group_size
=
bt
.
group_size
,
b_channel_scales
=
bt
.
w_ch_s
,
a_token_scales
=
bt
.
w_tok_s
,
out_type
=
out_type
,
schedule
=
schedule
,
)
# impl
# bench
def
bench_fn
(
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
)
->
TMeasurement
:
min_run_time
=
1
return
TBenchmark
.
Timer
(
stmt
=
"fn()"
,
globals
=
{
"fn"
:
fn
},
def
bench_fns
(
label
:
str
,
sub_label
:
str
,
description
:
str
,
fns
:
list
[
Callable
]):
min_run_time
=
1
if
not
NVTX_PROFILE
else
0.1
res
=
TBenchmark
.
Timer
(
stmt
=
"""
for fn in fns:
fn()
"""
,
globals
=
{
"fns"
:
fns
},
label
=
label
,
sub_label
=
sub_label
,
description
=
description
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
if
NVTX_PROFILE
:
with
(
nvtx
.
annotate
(
"mm-bench"
),
nvtx
.
annotate
(
f
"
{
label
}
|
{
sub_label
}
|
{
description
}
"
),
):
fns
[
0
]()
def
loop_over_weights
(
a
:
torch
.
tensor
,
weights
:
List
[
Tuple
[
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
]],
fn
:
Callable
[[
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
],
None
]):
for
w_ref
,
w_q
,
w_s
,
_
in
weights
:
fn
(
a
,
w_ref
,
w_q
,
w_s
)
return
res
_SWEEP_SCHEDULES_RESULTS
:
Optional
[
pd
.
DataFrame
]
=
None
_SWEEP_SCHEDULES_RESULTS_CSV
:
Optional
[
str
]
=
None
def
bench
(
atype
:
torch
.
dtype
,
wtype
:
ScalarType
,
group_size
:
int
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
benchmark_marlinv1
:
bool
=
True
,
sweep_schedules
:
bool
=
True
)
->
Iterable
[
TMeasurement
]:
global
_SWEEP_SCHEDULES_RESULTS
a
,
weights
=
make_bench_tensors
(
atype
,
wtype
,
group_size
,
m
,
n
,
k
)
sub_label
+=
f
", L=
{
len
(
weights
)
}
"
weights_machete
=
[(
w_ref
,
machete_pack_weights
(
w_q
,
wtype
),
w_s
,
w_zp
)
for
w_ref
,
w_q
,
w_s
,
w_zp
in
weights
]
def
bench
(
types
:
TypeConfig
,
group_size
:
int
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
sweep_schedules
:
bool
=
True
,
)
->
list
[
TMeasurement
]:
benchmark_tensors
=
create_bench_tensors
((
m
,
n
,
k
),
types
,
group_size
)
sub_label
+=
f
", L=
{
len
(
benchmark_tensors
)
}
"
name_type_string
=
f
"W
{
types
.
weight_type
}
"
+
f
"-A
{
terse_type_name
(
types
.
act_type
)
}
"
if
types
.
group_scale_type
is
not
None
:
name_type_string
+=
f
"-GS
{
terse_type_name
(
types
.
group_scale_type
)
}
"
if
types
.
group_zero_type
is
not
None
:
name_type_string
+=
f
"-GZ
{
terse_type_name
(
types
.
group_zero_type
)
}
"
if
group_size
is
not
None
:
name_type_string
+=
f
"-G
{
group_size
}
"
if
types
.
channel_scale_type
is
not
None
:
name_type_string
+=
f
"-CS
{
terse_type_name
(
types
.
channel_scale_type
)
}
"
if
types
.
token_scale_type
is
not
None
:
name_type_string
+=
f
"-TS
{
terse_type_name
(
types
.
token_scale_type
)
}
"
timers
=
[]
# pytorch impl
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"torch.matmul"
,
lambda
:
loop_over_weights
(
a
,
weights
,
lambda
a
,
w_ref
,
w_q
,
w_s
:
torch
.
matmul
(
a
,
w_ref
),
)))
if
benchmark_marlinv1
:
w_ref
=
weights
[
0
][
0
]
w_zp_empty
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w_ref
.
device
)
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w_ref
.
device
)
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w_ref
.
device
)
def
marlinv1_pack_weights
(
w_q
:
torch
.
tensor
)
->
torch
.
tensor
:
w_q_gptq
=
gptq_pack
(
w_q
,
wtype
.
size_bits
,
*
w_ref
.
shape
)
return
ops
.
gptq_marlin_repack
(
w_q_gptq
,
sort_indices
,
*
w_ref
.
shape
,
wtype
.
size_bits
)
def
marlinv1_permute_scales
(
w_s
:
torch
.
tensor
)
->
torch
.
tensor
:
return
marlin_permute_scales
(
w_s
,
*
w_ref
.
shape
,
group_size
)
weights_marlinv1
=
[(
w_ref
,
marlinv1_pack_weights
(
w_q
),
marlinv1_permute_scales
(
w_s
),
w_zp
)
for
w_ref
,
w_q
,
w_s
,
w_zp
in
weights
]
workspace
=
MarlinWorkspace
(
w_ref
.
shape
[
1
],
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
bench_fns
(
label
,
sub_label
,
"torch.matmul (fp16)"
,
[
torch_matmul_f16_create_bench_fn
(
bt
)
for
bt
in
benchmark_tensors
],
)
)
# marlinv1
if
types
.
act_type
==
torch
.
int8
or
types
.
act_type
==
torch
.
float8_e4m3fn
:
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"marlin_orig"
,
lambda
:
loop_over_weights
(
a
,
weights_marlinv1
,
lambda
a
,
w_ref
,
w_q
,
w_s
:
ops
.
gptq_marlin_gemm
(
a
,
w_q
,
w_s
,
w_zp_empty
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
wtype
,
size_m
=
a
.
shape
[
0
],
size_n
=
w_ref
.
shape
[
1
],
size_k
=
w_ref
.
shape
[
0
],
is_k_full
=
True
))))
bench_fns
(
label
,
sub_label
,
f
"cutlass_scaled_mm (
{
terse_type_name
(
types
.
act_type
)
}
)"
,
[
cutlass_scaled_mm_create_bench_fn
(
bt
)
for
bt
in
benchmark_tensors
],
)
)
if
types
.
act_type
!=
torch
.
float8_e4m3fn
:
timers
.
append
(
bench_fns
(
label
,
sub_label
,
f
"marlin (
{
name_type_string
}
)"
,
[
marlin_create_bench_fn
(
bt
)
for
bt
in
benchmark_tensors
],
)
)
# machete
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"machete_heuristic"
,
lambda
:
loop_over_weights
(
a
,
weights_machete
,
lambda
a
,
_
,
w_q
,
w_s
:
ops
.
machete_gemm
(
a
,
w_q
,
wtype
,
b_scales
=
w_s
,
b_group_size
=
group_size
))))
bench_fns
(
label
,
sub_label
,
f
"machete (
{
name_type_string
}
)"
,
[
machete_create_bench_fn
(
bt
,
out_type
=
types
.
output_type
)
for
bt
in
benchmark_tensors
],
)
)
if
sweep_schedules
:
global
_SWEEP_SCHEDULES_RESULTS
print
(
"Finding best schedule for machete"
)
best
=
None
best_schedule
=
None
schedules
=
ops
.
machete_supported_schedules
(
wtype
)
schedules
=
ops
.
machete_supported_schedules
(
a_type
=
types
.
act_type
,
b_type
=
types
.
weight_type
,
group_scales_type
=
types
.
group_scale_type
,
group_zeros_type
=
types
.
group_zero_type
,
token_scales_type
=
types
.
token_scale_type
,
channel_scales_type
=
types
.
channel_scale_type
,
out_type
=
types
.
output_type
,
)
if
schedules
is
None
or
len
(
schedules
)
==
0
:
raise
ValueError
(
"No schedules found to sweep"
)
for
schedule
in
reversed
(
schedules
):
schedule_M
=
int
(
schedule
.
split
(
"_"
)[
0
].
split
(
"x"
)[
1
])
...
...
@@ -177,16 +431,17 @@ def bench(atype: torch.dtype,
if
schedule_M
>=
2
*
max
(
m
,
16
)
or
schedule_M
<
m
//
4
:
continue
def
run
(
a
,
_
,
w_q
,
w_s
,
schedule
=
schedule
):
ops
.
machete_gemm
(
a
,
w_q
,
wtype
,
w_s
,
b_group_size
=
group_size
,
schedule
=
schedule
)
res
=
bench_fn
(
label
,
sub_label
,
"machete_best"
,
lambda
:
loop_over_weights
(
a
,
weights_machete
,
run
))
res
=
bench_fns
(
label
,
sub_label
,
"machete_best"
,
[
machete_create_bench_fn
(
bt
,
out_type
=
types
.
output_type
,
schedule
=
schedule
)
for
bt
in
benchmark_tensors
],
)
results_row
=
{
"M"
:
m
,
...
...
@@ -197,10 +452,8 @@ def bench(atype: torch.dtype,
"median"
:
res
.
median
,
}
if
_SWEEP_SCHEDULES_RESULTS
is
None
:
_SWEEP_SCHEDULES_RESULTS
=
pd
.
DataFrame
(
columns
=
results_row
.
keys
())
_SWEEP_SCHEDULES_RESULTS
.
\
loc
[
len
(
_SWEEP_SCHEDULES_RESULTS
)]
=
results_row
_SWEEP_SCHEDULES_RESULTS
=
pd
.
DataFrame
(
columns
=
results_row
.
keys
())
_SWEEP_SCHEDULES_RESULTS
.
loc
[
len
(
_SWEEP_SCHEDULES_RESULTS
)]
=
results_row
print
(
f
"
{
res
.
median
:
5.5
}
"
,
schedule
)
if
not
best
or
res
.
median
<
best
.
median
:
...
...
@@ -213,25 +466,36 @@ def bench(atype: torch.dtype,
# runner
def
print_timers
(
timers
:
Iterable
[
TMeasurement
]):
def
print_timers
(
timers
:
list
[
TMeasurement
]):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
def
run
(
dtype
:
torch
.
dtype
,
sweep_schedules
:
bool
,
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]])
->
Iterable
[
TMeasurement
]:
def
run
(
args
,
MKNs
:
Iterable
[
tuple
[
int
,
int
,
int
]])
->
Iterable
[
TMeasurement
]:
types
=
TypeConfig
(
act_type
=
args
.
act_type
,
weight_type
=
scalar_types
.
uint4b8
if
args
.
group_zero_type
is
None
else
scalar_types
.
uint4
,
output_type
=
args
.
out_type
,
group_scale_type
=
args
.
group_scale_type
,
group_zero_type
=
args
.
group_zero_type
,
channel_scale_type
=
args
.
channel_scale_type
,
token_scale_type
=
args
.
token_scale_type
,
)
results
=
[]
results
:
list
[
TMeasurement
]
=
[]
for
m
,
k
,
n
in
MKNs
:
timers
=
bench
(
dtype
,
scalar_types
.
uint4b8
,
128
,
m
,
k
,
n
,
f
"
{
dtype
}
-gemm"
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
,
sweep_schedules
=
sweep_schedules
)
timers
=
bench
(
types
,
args
.
group_size
,
m
,
k
,
n
,
f
"
{
args
.
act_type
}
-gemm"
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
,
sweep_schedules
=
args
.
sweep_schedules
,
)
print_timers
(
timers
)
results
.
extend
(
timers
)
...
...
@@ -240,12 +504,11 @@ def run(dtype: torch.dtype, sweep_schedules: bool,
# output makers
def
make_output
(
data
:
Iterable
[
TMeasurement
],
MKNs
:
Iterable
[
T
uple
[
int
,
int
,
int
]],
data
:
list
[
TMeasurement
],
MKNs
:
Iterable
[
t
uple
[
int
,
int
,
int
]],
base_description
:
str
,
timestamp
=
None
,
):
print
(
f
"== All Results
{
base_description
}
===="
)
print_timers
(
data
)
...
...
@@ -259,20 +522,19 @@ def make_output(
def
run_square_bench
(
args
):
dim_sizes
=
list
(
range
(
args
.
dim_start
,
args
.
dim_end
+
1
,
args
.
dim_increment
))
dim_sizes
=
list
(
range
(
args
.
dim_start
,
args
.
dim_end
+
1
,
args
.
dim_increment
))
MKNs
=
list
(
zip
(
dim_sizes
,
dim_sizes
,
dim_sizes
))
data
=
run
(
args
.
dtype
,
args
.
sweep_schedules
,
MKNs
)
make_output
(
data
,
MKNs
,
f
"square_bench-
{
args
.
dtype
}
"
)
def
run_range_bench
(
args
):
m_start
,
k_start
,
n_start
=
[
int
(
x
)
for
x
in
args
.
dim_start
.
split
(
","
)]
m_end
,
k_end
,
n_end
=
[
int
(
x
)
for
x
in
args
.
dim_end
.
split
(
","
)]
m_increment
,
k_increment
,
n_increment
=
\
[
int
(
x
)
for
x
in
args
.
dim_increment
.
split
(
","
)]
m_start
,
k_start
,
n_start
=
(
int
(
x
)
for
x
in
args
.
dim_start
.
split
(
","
))
m_end
,
k_end
,
n_end
=
(
int
(
x
)
for
x
in
args
.
dim_end
.
split
(
","
))
m_increment
,
k_increment
,
n_increment
=
(
int
(
x
)
for
x
in
args
.
dim_increment
.
split
(
","
)
)
Ms
=
list
(
range
(
m_start
,
m_end
+
1
,
m_increment
))
Ks
=
list
(
range
(
k_start
,
k_end
+
1
,
k_increment
))
Ns
=
list
(
range
(
n_start
,
n_end
+
1
,
n_increment
))
...
...
@@ -284,12 +546,11 @@ def run_range_bench(args):
def
run_model_bench
(
args
):
print
(
"Benchmarking models:"
)
for
i
,
model
in
enumerate
(
args
.
models
):
print
(
f
"[
{
i
}
]
{
model
}
"
)
def
model_shapes
(
model_name
:
str
,
tp_size
:
int
)
->
L
ist
[
T
uple
[
int
,
int
]]:
def
model_shapes
(
model_name
:
str
,
tp_size
:
int
)
->
l
ist
[
t
uple
[
int
,
int
]]:
KNs
=
[]
for
KN
,
tp_split_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model_name
]):
KN
[
tp_split_dim
]
=
KN
[
tp_split_dim
]
//
tp_size
...
...
@@ -306,33 +567,51 @@ def run_model_bench(args):
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, args.sweep_schedules, MKNs)
        data = run(args, MKNs)
        model_bench_data.append(data)

    type_string = f"{args.act_type}"

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print(f"== Results {type_string} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())
    timestr = time.strftime("%Y%m%d-%H%M%S")

    all_data = []
    all_results = []
    for d in model_bench_data:
        all_data.extend(d)
        all_results.extend(d)

    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)
    with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
        args_dict = vars(args)
        args_dict.pop("func")
        pkl.dump(
            {
                "args": args_dict,
                "results": all_results,
            },
            f,
        )


if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "bfloat16":
            return torch.bfloat16
        if dt == "float16":
            return torch.float16
        raise ValueError("unsupported dtype")
        return {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "int8": torch.int8,
            "float8_e4m3fn": torch.float8_e4m3fn,
            "int": torch.int,
            "float": torch.float,
        }[dt]

    class ToTorchDtype(argparse.Action):
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, to_torch_dtype(values))

    parser = FlexibleArgumentParser(
        description="""
...
...
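The model_bench path above now pickles a dict containing the parsed arguments and the collected measurements instead of a bare list. A minimal sketch for reading such a file back (the filename below is a hypothetical example that simply follows the f"model_bench-{type_string}-{timestr}.pkl" pattern shown above):

import pickle as pkl

# Hypothetical example path following the naming pattern in run_model_bench
with open("model_bench-torch.float16-20251023-120000.pkl", "rb") as f:
    payload = pkl.load(f)

print(payload["args"])                 # dict of parsed CLI arguments ("func" removed)
for measurement in payload["results"]:  # list of TMeasurement objects
    print(measurement)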
@@ -352,21 +631,53 @@ Benchmark Machete GEMM.
"""
,
# noqa: E501
formatter_class
=
argparse
.
RawTextHelpFormatter
,
)
parser
.
add_argument
(
"--
d
type"
,
type
=
to_t
orch
_d
type
,
"--
act-
type"
,
action
=
ToT
orch
D
type
,
required
=
True
,
help
=
"Available options are ['bfloat16', 'float16']"
,
choices
=
[
"bfloat16"
,
"float16"
,
"int8"
,
"float8_e4m3fn"
],
)
parser
.
add_argument
(
"--group-scale-type"
,
action
=
ToTorchDtype
,
choices
=
[
"bfloat16"
,
"float16"
],
)
parser
.
add_argument
(
"--group-zero-type"
,
type
=
to_torch_dtype
,
choices
=
[
"bfloat16"
,
"float16"
],
)
parser
.
add_argument
(
"--channel-scale-type"
,
action
=
ToTorchDtype
,
choices
=
[
"float"
],
)
parser
.
add_argument
(
"--token-scale-type"
,
action
=
ToTorchDtype
,
choices
=
[
"float"
],
)
parser
.
add_argument
(
"--out-type"
,
action
=
ToTorchDtype
,
choices
=
[
"bfloat16"
,
"float16"
],
)
parser
.
add_argument
(
"--group-size"
,
type
=
int
,
help
=
"Available options are ['None', '-1', '128'], default=128"
,
default
=
128
,
)
parser
.
add_argument
(
"--sweep-schedules"
,
action
=
"store_true"
,
help
=
"Run a sweep over all supported schedules"
,
)
parser
.
add_argument
(
"--sweep-csv-out"
,
help
=
"CSV to store sweep results"
,
default
=
"sch_sweep_results.csv"
)
parser
.
add_argument
(
"--sweep-csv-out"
,
help
=
"CSV to store sweep results"
,
default
=
"sch_sweep_results.csv"
,
)
subparsers
=
parser
.
add_subparsers
(
dest
=
"cmd"
,
required
=
True
)
square_parser
=
subparsers
.
add_parser
(
"square_bench"
)
...
...
@@ -380,17 +691,20 @@ Benchmark Machete GEMM.
"--dim-start"
,
type
=
str
,
required
=
True
,
help
=
"Start value for M,K,N as common separated list"
)
help
=
"Start value for M,K,N as common separated list"
,
)
range_parser
.
add_argument
(
"--dim-end"
,
type
=
str
,
required
=
True
,
help
=
"End value (inclusive) for M,K,N as common separated list"
)
help
=
"End value (inclusive) for M,K,N as common separated list"
,
)
range_parser
.
add_argument
(
"--dim-increment"
,
type
=
str
,
required
=
True
,
help
=
"Increment value for M,K,N as common separated list"
)
help
=
"Increment value for M,K,N as common separated list"
,
)
range_parser
.
set_defaults
(
func
=
run_range_bench
)
model_parser
=
subparsers
.
add_parser
(
"model_bench"
)
...
...
@@ -401,14 +715,12 @@ Benchmark Machete GEMM.
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
...
...
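The machete parser above replaces type=to_torch_dtype with a custom argparse.Action so that string choices become torch dtypes at parse time. A small self-contained sketch of that pattern, reduced to two of the dtype names used above:

import argparse
import torch

def to_torch_dtype(dt: str) -> torch.dtype:
    # Same mapping style as in the benchmark script above
    return {"bfloat16": torch.bfloat16, "float16": torch.float16}[dt]

class ToTorchDtype(argparse.Action):
    # Convert the raw string choice into a torch.dtype when the flag is parsed
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, to_torch_dtype(values))

parser = argparse.ArgumentParser()
parser.add_argument("--act-type", action=ToTorchDtype, choices=["bfloat16", "float16"])
args = parser.parse_args(["--act-type", "float16"])
print(args.act_type)  # torch.float16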
benchmarks/kernels/benchmark_marlin.py
View file @ c2170174
from typing import List

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.utils.benchmark as benchmark
...
...
@@ -6,78 +7,202 @@ from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES,
    query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    FP4_MARLIN_SUPPORTED_GROUP_SIZES,
    rand_marlin_weight_fp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, marlin_quantize)
    MarlinWorkspace,
    awq_marlin_quantize,
    marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
    marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, gptq_quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]


def bench_run(results: List[benchmark.Measurement], model: str,
              act_order: bool, is_k_full: bool, quant_type: ScalarType,
              group_size: int, size_m: int, size_k: int, size_n: int):
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    act_order: bool,
    is_k_full: bool,
    quant_type: ScalarType,
    group_size: int,
    size_m: int,
    size_k: int,
    size_n: int,
):
    label = "Quant Matmul"

    sub_label = ("{}, act={} k_full={}, q={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full,
                                         str(quant_type), group_size,
                                         size_m, size_k, size_n))
    sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
        model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
    )

    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
    if act_order and (group_size == -1 or group_size == size_k or has_zp):
        return
    if size_k % group_size != 0:
        return

    marlin_24_supported = (
        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
        and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
    )
    repack_supported = (
        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
        and group_size in MARLIN_SUPPORTED_GROUP_SIZES
    )
    allspark_supported = (
        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
        and group_size == -1
        and not act_order
        and is_k_full
    )

    a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())

    def gen_marlin_params():
        # Marlin quant
        marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
        if quant_type == scalar_types.float4_e2m1f:
            if group_size != 16 or act_order:
                return
            marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
                b.T, group_size
            )
        elif quant_type == scalar_types.float8_e4m3fn:
            if group_size not in [-1, 128] or act_order:
                return
            marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
        elif group_size == 16:
            return
        elif has_zp:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
                b, quant_type, group_size
            )
        else:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
                marlin_quantize(b, quant_type, group_size, act_order)
            )
        return (
            marlin_w_ref,
            marlin_q_w,
            marlin_s,
            marlin_s2,
            marlin_zp,
            marlin_g_idx,
            marlin_sort_indices,
        )

    def gen_marlin_24_params():
        marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
        if marlin_24_supported:
            (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
                marlin_24_quantize(b, quant_type, group_size)
            )
        return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)

    def gen_repack_params():
        q_w_gptq = None
        repack_sort_indices = None
        if repack_supported:
            (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
                b, quant_type, group_size, act_order
            )
            q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
            # For act_order, sort the "weights" and "g_idx"
            # so that group ids are increasing
            repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
            if act_order:
                (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
        return q_w_gptq, repack_sort_indices

    def gen_allspark_params():
        qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
            CUBLAS_M_THRESHOLD
        ) = None
        nonlocal allspark_supported
        if allspark_supported:
            properties = torch.cuda.get_device_properties(b.device.index)
            sm_count = properties.multi_processor_count
            sm_version = properties.major * 10 + properties.minor

            supported_arch = sm_version >= 80 and sm_version < 90
            allspark_supported = allspark_supported and supported_arch
            if supported_arch:
                w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
                qw = qw.to(torch.uint8)

                qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
                    qw, s, zp, has_zp
                )
                CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
        return (
            qw_reorder,
            s_reorder,
            zp_reorder,
            sm_count,
            sm_version,
            CUBLAS_M_THRESHOLD,
        )

    # Marlin quant
    (marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices,
     marlin_rand_perm) = marlin_quantize(b, quant_type, group_size, act_order)

    # Marlin_24 quant
    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)

    marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)

    # GPTQ quant
    (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
        b, quant_type, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx"
    # so that group ids are increasing
    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        marlin_g_idx,
        marlin_sort_indices,
    ) = gen_marlin_params()
    marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
        gen_marlin_24_params()
    )
    q_w_gptq, repack_sort_indices = gen_repack_params()
    qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
        gen_allspark_params()
    )

    # Prepare
    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                       GPTQ_MARLIN_MAX_PARALLEL)
    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                          GPTQ_MARLIN_24_MAX_PARALLEL)
    marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
    marlin_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )
    marlin_24_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    globals = {
        # Gen params
...
...
@@ -87,15 +212,14 @@ def bench_run(results: List[benchmark.Measurement], model: str,
"size_n"
:
size_n
,
"size_k"
:
size_k
,
"a"
:
a
,
"a_tmp"
:
a_tmp
,
# Marlin params
"marlin_w_ref"
:
marlin_w_ref
,
"marlin_q_w"
:
marlin_q_w
,
"marlin_s"
:
marlin_s
,
"marlin_s2"
:
marlin_s2
,
"marlin_zp"
:
marlin_zp
,
"marlin_g_idx"
:
marlin_g_idx
,
"marlin_sort_indices"
:
marlin_sort_indices
,
"marlin_rand_perm"
:
marlin_rand_perm
,
"marlin_workspace"
:
marlin_workspace
,
"is_k_full"
:
is_k_full
,
# Marlin_24 params
...
...
@@ -107,16 +231,24 @@ def bench_run(results: List[benchmark.Measurement], model: str,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # AllSpark W8A16 params
        "qw_reorder": qw_reorder,
        "s_reorder": s_reorder,
        "zp_reorder": zp_reorder,
        "sm_count": sm_count,
        "sm_version": sm_version,
        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
    }

    min_run_time = 1

    # Warmup pytorch
    for i in range(5):
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    results.append(
...
...
@@ -126,57 +258,68 @@ def bench_run(results: List[benchmark.Measurement], model: str,
            label=label,
            sub_label=sub_label,
            description="pytorch_gemm",
        ).blocked_autorange(min_run_time=min_run_time))
        ).blocked_autorange(min_run_time=min_run_time)
    )

    results.append(
        benchmark.Timer(
            stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)",  # noqa: E501
            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm_fp16",
        ).blocked_autorange(min_run_time=min_run_time))
            description="gptq_marlin_gemm",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    results.append(
        benchmark.Timer(
            stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)",  # noqa: E501
            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm_fp32",
        ).blocked_autorange(min_run_time=min_run_time))
        ).blocked_autorange(min_run_time=min_run_time)
    )

    if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
            and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
    if marlin_24_supported:
        results.append(
            benchmark.Timer(
                stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
                stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="gptq_marlin_24_gemm",
            ).blocked_autorange(min_run_time=min_run_time))
            ).blocked_autorange(min_run_time=min_run_time)
        )

    results.append(
        benchmark.Timer(
            stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_repack",
        ).blocked_autorange(min_run_time=min_run_time))

    if repack_supported:
        results.append(
            benchmark.Timer(
                stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="gptq_marlin_repack",
            ).blocked_autorange(min_run_time=min_run_time)
        )

    if allspark_supported:
        results.append(
            benchmark.Timer(
                stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="allspark_w8a16_gemm_fp32",
            ).blocked_autorange(min_run_time=min_run_time)
        )


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: List[benchmark.Measurement] = []
    results: list[benchmark.Measurement] = []
    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
...
...
@@ -190,37 +333,53 @@ def main(args):
                    continue

                for act_order in ACT_ORDER_OPTS:
                    if len(args.limit_act_order) > 0 and act_order not in args.limit_act_order:
                    if (
                        len(args.limit_act_order) > 0
                        and act_order not in args.limit_act_order
                    ):
                        continue

                    for is_k_full in K_FULL_OPTS:
                        if len(args.limit_k_full) > 0 and is_k_full not in args.limit_k_full:
                        if (
                            len(args.limit_k_full) > 0
                            and is_k_full not in args.limit_k_full
                        ):
                            continue

                        for quant_type in query_marlin_supported_quant_types(False):
                            if len(args.limit_num_bits) > 0 and \
                                    quant_type.size_bits not in args.limit_num_bits:
                        for quant_type in query_marlin_supported_quant_types():
                            if (
                                len(args.limit_num_bits) > 0
                                and quant_type.size_bits not in args.limit_num_bits
                            ):
                                continue

                            for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
                                if len(args.limit_group_size) > 0 and group_size not in args.limit_group_size:
                            for group_size in (
                                MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES
                            ):
                                if (
                                    len(args.limit_group_size) > 0
                                    and group_size not in args.limit_group_size
                                ):
                                    continue

                                # For act_order, the group_size must be less than
                                # size_k
                                if act_order and (group_size == size_k or group_size == -1):
                                if act_order and (group_size == size_k or group_size == -1):
                                    continue

                                for size_m in args.batch_sizes:
                                    bench_run(results, model, act_order, is_k_full,
                                              quant_type, group_size, size_m,
                                              size_k, size_n)
                                    bench_run(
                                        results,
                                        model,
                                        act_order,
                                        is_k_full,
                                        quant_type,
                                        group_size,
                                        size_m,
                                        size_k,
                                        size_n,
                                    )

    compare = benchmark.Compare(results)
    compare.print()
...
...
@@ -231,7 +390,8 @@ def main(args):
#
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
        description="Benchmark Marlin across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
...
...
@@ -239,10 +399,9 @@ if __name__ == "__main__":
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
...
...
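benchmark_marlin.py above builds all of its measurements with torch.utils.benchmark: each kernel variant gets a Timer whose stmt executes against a shared globals dict, and benchmark.Compare prints the collected Measurement objects side by side. A minimal self-contained sketch of that pattern (plain matmul only, no vLLM kernels, CPU-friendly) looks like this:

import torch
import torch.utils.benchmark as benchmark

results = []
a = torch.randn(128, 256)
b = torch.randn(256, 64)

for desc in ["matmul", "einsum"]:
    stmt = "torch.matmul(a, b)" if desc == "matmul" else "torch.einsum('ik,kj->ij', a, b)"
    results.append(
        benchmark.Timer(
            stmt=stmt,
            globals={"a": a, "b": b, "torch": torch},
            label="Quant Matmul",           # rows are grouped by label/sub_label
            sub_label="MKN=(128x256x64)",
            description=desc,               # columns are grouped by description
        ).blocked_autorange(min_run_time=0.2)
    )

benchmark.Compare(results).print()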
benchmarks/kernels/benchmark_moe.py
View file @ c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict
from itertools import product
from typing import Any, TypedDict, Optional

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser, seed_everything
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser

# Removed the global current_platform import; it is now imported locally where needed.
# FP8_DTYPE = current_platform.fp8_dtype()


class BenchmarkConfig(TypedDict):
...
...
@@ -20,6 +29,7 @@ class BenchmarkConfig(TypedDict):
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int
    num_ldmatrixes: Optional[int]


def benchmark_config(
...
...
@@ -33,77 +43,167 @@ def benchmark_config(
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: list[int] = None,
    use_deep_gemm: bool = False,
    nn_moe: Optional[bool] = False,
) -> float:
    from vllm.platforms import current_platform

    device = torch.cuda.current_device()
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    if use_int8_w8a16:
        w1 = torch.randint(-127, 127,
                           (num_experts, shard_intermediate_size, hidden_size),
                           dtype=torch.int8)
        w2 = torch.randint(-127, 127,
                           (num_experts, hidden_size, shard_intermediate_size // 2),
                           dtype=torch.int8)
        if not nn_moe:
            w1 = torch.randint(-127, 127,
                               (num_experts, shard_intermediate_size, hidden_size),
                               dtype=torch.int8, device=device)
            w2 = torch.randint(-127, 127,
                               (num_experts, hidden_size, shard_intermediate_size // 2),
                               dtype=torch.int8, device=device)
        else:
            w1 = torch.randint(-127, 127,
                               (num_experts, hidden_size, shard_intermediate_size),
                               dtype=torch.int8, device=device)
            w2 = torch.randint(-127, 127,
                               (num_experts, shard_intermediate_size // 2, hidden_size),
                               dtype=torch.int8, device=device)
    else:
        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size,
                         dtype=init_dtype)
        w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2,
                         dtype=init_dtype)
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
        if not nn_moe:
            w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size,
                             dtype=init_dtype, device=device)
            w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2,
                             dtype=init_dtype, device=device)
        else:
            w1 = torch.randn(num_experts, hidden_size, shard_intermediate_size,
                             dtype=init_dtype, device=device)
            w2 = torch.randn(num_experts, shard_intermediate_size // 2, hidden_size,
                             dtype=init_dtype, device=device)
    gating_output = torch.randn(
        num_iters, num_tokens, num_experts, dtype=torch.float32, device=device
    )

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32,
                               device=device)
    if use_fp8_w8a8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device)
                * factor_for_scale
            )
        else:
            w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
            w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device)

        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        a1_scale = torch.randn(1, dtype=torch.float32, device=device)
        a2_scale = torch.randn(1, dtype=torch.float32, device=device)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
        # Get FP8_DTYPE
        FP8_DTYPE = current_platform.fp8_dtype()
        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        fused_moe(
            x,
            w1,
            w2,
            input_gating,
            topk,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )
        from vllm.model_executor.layers.fused_moe import override_config

        with override_config(config):
            if use_deep_gemm:
                topk_weights, topk_ids, token_expert_indices = fused_topk(
                    x, input_gating, topk, False
                )
                return fused_experts(
                    x,
                    w1,
                    w2,
                    topk_weights,
                    topk_ids,
                    inplace=True,
                    use_fp8_w8a8=use_fp8_w8a8,
                    w1_scale=w1_scale,
                    w2_scale=w2_scale,
                    a1_scale=a1_scale,
                    a2_scale=a2_scale,
                    block_shape=block_quant_shape,
                    allow_deep_gemm=True,
                    use_nn_moe=nn_moe,
                )
            else:
                fused_moe(
                    x,
                    w1,
                    w2,
                    input_gating,
                    topk,
                    renormalize=True,
                    inplace=True,
                    use_fp8_w8a8=use_fp8_w8a8,
                    use_int8_w8a16=use_int8_w8a16,
                    w1_scale=w1_scale,
                    w2_scale=w2_scale,
                    a1_scale=a1_scale,
                    a2_scale=a2_scale,
                    block_shape=block_quant_shape,
                    use_nn_moe=nn_moe,
                )

    # JIT compilation & warmup
    run()
...
...
@@ -119,18 +219,20 @@ def benchmark_config(
    # Warmup
    for _ in range(5):
        graph.replay()
        # run()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    latencies: List[float] = []
    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        # run()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
...
...
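The timing loop above measures each replayed CUDA graph with a pair of torch.cuda.Event objects. A minimal sketch of the same event-based timing pattern for an arbitrary GPU callable (guarded so it only runs where CUDA is available) might look like:

import torch

def time_gpu(fn, iters: int = 10) -> float:
    """Return the average latency of fn() in milliseconds using CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    latencies = []
    for _ in range(iters):
        torch.cuda.synchronize()
        start.record()
        fn()
        end.record()
        end.synchronize()
        latencies.append(start.elapsed_time(end))  # milliseconds
    return sum(latencies) / iters

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    print(f"matmul avg latency: {time_gpu(lambda: a @ b):.3f} ms")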
@@ -139,35 +241,217 @@ def benchmark_config(
    return avg


def get_configs_compute_bound() -> List[Dict[str, int]]:
    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    configs: List[BenchmarkConfig] = []
    for num_stages in [2, 3, 4, 5]:
        for block_m in [16, 32, 64, 128, 256]:
            for block_k in [64, 128, 256]:
                for block_n in [32, 64, 128, 256]:
                    for num_warps in [4, 8]:
                        for group_size in [1, 16, 32, 64]:
                            configs.append({
                                "BLOCK_SIZE_M": block_m,
                                "BLOCK_SIZE_N": block_n,
                                "BLOCK_SIZE_K": block_k,
                                "GROUP_SIZE_M": group_size,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })


def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
    block_m_range = [16, 32, 64, 128, 256]
    block_n_range = [32, 64, 128, 256]
    block_k_range = [32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
    num_warps_range = [2, 4, 8]
    group_m_range = [1, 16, 32, 64]
    num_stage_range = [2, 3, 4, 5]
    # waves_per_eu_range = [0]
    # matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
    # kpack_range = [1, 2] if use_fp16 else []

    param_ranges = {
        "BLOCK_SIZE_M": block_m_range,
        "BLOCK_SIZE_N": block_n_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": group_m_range,
        "num_warps": num_warps_range,
        "num_stages": num_stage_range,
        # "waves_per_eu": waves_per_eu_range,
    }
    if nn_moe:
        param_ranges["num_ldmatrixes"] = [1]
    # DCU currently does not support the following parameters
    # if use_fp16:
    #     param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
    #     param_ranges["kpack"] = kpack_range

    return param_ranges


def get_configs_compute_bound(
    use_fp16, block_quant_shape, nn_moe: Optional[bool] = False
) -> list[dict[str, int]]:
    configs: list[BenchmarkConfig] = []

    # Import current_platform locally
    from vllm.platforms import current_platform

    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        block_m_range = [16, 32, 64, 128, 256]
        block_n_range = [32, 64, 128, 256]
        block_k_range = [64, 128, 256]
        num_warps_range = [4, 8]
        group_m_range = [1, 16, 32, 64]
        num_stage_range = [2, 3, 4, 5]

        param_ranges = {
            "BLOCK_SIZE_M": block_m_range,
            "BLOCK_SIZE_N": block_n_range,
            "BLOCK_SIZE_K": block_k_range,
            "GROUP_SIZE_M": group_m_range,
            "num_warps": num_warps_range,
            "num_stages": num_stage_range,
        }

    keys, values = zip(*param_ranges.items())
    for config_values in product(*values):
        config = dict(zip(keys, config_values))
        configs.append(config)

    # Remove configs that are not compatible with fp8 block quantization
    # BLOCK_SIZE_K must be a multiple of block_k
    # BLOCK_SIZE_N must be a multiple of block_n
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        for config in configs[:]:
            if (
                config["BLOCK_SIZE_K"] % block_k != 0
                or config["BLOCK_SIZE_N"] % block_n != 0
            ):
                configs.remove(config)

    return configs


def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    N1, K1 = shard_intermediate_size, hidden_size
    N2, K2 = hidden_size, shard_intermediate_size // 2
    pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, search_space, is_fp16)
    pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, search_space, is_fp16)
    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
    return search_space


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    pruned_configs = []
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = False
    if M >= 2048 and N >= 2048:
        large_gemm = True

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        # DCU currently does not support matrix_instr_nonkdim param
        # if is_fp16:
        #     matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
        #     if matrix_instr_nonkdim > mfma:
        #         continue
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        # DCU currently does not support matrix_instr_nonkdim param
        # if is_fp16:
        #     if (
        #         matrix_instr_nonkdim > BLOCK_SIZE_M
        #         or matrix_instr_nonkdim > BLOCK_SIZE_N
        #     ):
        #         continue
        #     if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
        #         continue
        #     if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
        #         continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def merge_unique_dicts(list1, list2):
    result = []
    combined_list = list1.copy()
    combined_list.extend(list2)
    for dictionary in combined_list:
        if dictionary not in result:
            result.append(dictionary)
    return result


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        seed_everything(seed)
    def __init__(self, seed: int, device_id: int) -> None:
        from vllm.platforms import current_platform
        import os

        if current_platform.is_rocm():
            # In ROCm environment with Ray, let Ray handle device assignment
            # Don't manually set default device as it may conflict with Ray's device mapping
            pass
        else:
            torch.set_default_device("cuda:" + str(device_id))
        current_platform.seed_everything(seed)
        self.seed = seed
        # Store the logical device ID for Ray
        self.device_id = device_id

    def benchmark(
        self,
...
...
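get_configs_compute_bound above builds the tuning search space by taking the Cartesian product of the per-parameter ranges with itertools.product. A small self-contained sketch of that grid-expansion idiom, with a reduced, made-up set of ranges for illustration:

from itertools import product

param_ranges = {
    "BLOCK_SIZE_M": [16, 32],
    "BLOCK_SIZE_N": [32, 64],
    "num_warps": [4, 8],
}

# zip(*dict.items()) splits the dict into parallel key/value tuples,
# and product(*values) enumerates every combination of the ranges.
keys, values = zip(*param_ranges.items())
configs = [dict(zip(keys, combo)) for combo in product(*values)]

print(len(configs))   # 2 * 2 * 2 = 8 candidate configs
print(configs[0])     # {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'num_warps': 4}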
@@ -179,26 +463,53 @@ class BenchmarkWorker:
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> Tuple[Dict[str, int], float]:
        seed_everything(self.seed)
        dtype_str = get_config_dtype_str(dtype,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        block_quant_shape: list[int] = None,
        use_deep_gemm: bool = False,
        nn_moe: Optional[bool] = False,
    ) -> tuple[dict[str, int], float]:
        # Import current_platform locally
        from vllm.platforms import current_platform

        current_platform.seed_everything(self.seed)
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            get_config_dtype_str, get_moe_configs, get_default_config)

        dtype_str = get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, dtype_str)
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, use_nn_moe=nn_moe
        )
        if op_config is None:
            config = get_default_config(num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
                                        topk, dtype_str)
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                is_marlin=False,
                use_nn_moe=nn_moe,
            )
        else:
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8_w8a8, use_int8_w8a16)
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
            use_nn_moe=nn_moe,
        )
        return config, kernel_time

    def tune(
...
...
@@ -211,29 +522,63 @@ class BenchmarkWorker:
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
        nn_moe: Optional[bool] = False,
    ) -> dict[str, int]:
        from vllm.platforms import current_platform
        import os

        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
            try:
                kernel_time = benchmark_config(config, num_tokens, num_experts,
                                               shard_intermediate_size, hidden_size,
                                               topk, dtype, use_fp8_w8a8,
                                               use_int8_w8a16, num_iters=10)
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
                continue

            if kernel_time < best_time:
                best_time = kernel_time
                best_config = config
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )

        # In ROCm environments with Ray, device context is already handled by Ray
        # Using torch.cuda.device() may cause device ordinal conflicts
        need_device_guard = False
        if current_platform.is_rocm():
            # For ROCm with Ray, skip additional device context management
            need_device_guard = False
        else:
            # For other platforms, use device guard if needed
            visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
            if visible_devices is not None and len(visible_devices.split(',')) > 1:
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                        nn_moe=nn_moe,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
...
...
@@ -241,6 +586,7 @@ class BenchmarkWorker:
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
...
...
@@ -248,21 +594,46 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
"GROUP_SIZE_M"
:
config
[
"GROUP_SIZE_M"
],
"num_warps"
:
config
[
"num_warps"
],
"num_stages"
:
config
[
"num_stages"
],
**
(
{
"num_ldmatrixes"
:
config
[
"num_ldmatrixes"
]}
if
"num_ldmatrixes"
in
config
else
{}
),
**
(
{
"waves_per_eu"
:
config
[
"waves_per_eu"
]}
if
"waves_per_eu"
in
config
else
{}
),
**
(
{
"matrix_instr_nonkdim"
:
config
[
"matrix_instr_nonkdim"
]}
if
"matrix_instr_nonkdim"
in
config
else
{}
),
**
({
"kpack"
:
config
[
"kpack"
]}
if
"kpack"
in
config
else
{}),
}
def
save_configs
(
configs
:
Dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
)
->
None
:
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
def
save_configs
(
configs
:
dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
list
[
int
],
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
None
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
get_config_file_name
)
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
)
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
,
block_quant_shape
,
use_nn_moe
=
use_nn_moe
)
print
(
f
"Writing best config to
{
filename
}
..."
)
with
open
(
filename
,
"w"
)
as
f
:
...
...
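sort_config above relies on the conditional dict-spread idiom, **({...} if key in config else {}), so that optional tuning keys only appear in the output when they were present in the input. A tiny standalone sketch of that idiom with made-up values:

config = {"BLOCK_SIZE_M": 64, "num_warps": 4, "kpack": 2}

ordered = {
    "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
    "num_warps": config["num_warps"],
    # Only include optional keys that are actually present.
    **({"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}),
    **({"kpack": config["kpack"]} if "kpack" in config else {}),
}

print(ordered)  # {'BLOCK_SIZE_M': 64, 'num_warps': 4, 'kpack': 2}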
@@ -270,47 +641,108 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
        f.write("\n")


def get_weight_block_size_safety(config, default_value=None):
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        return quantization_config.get("weight_block_size", default_value)
    return default_value


def main(args: argparse.Namespace):
    import os
    import logging
    from vllm.platforms import current_platform

    logger = logging.getLogger(__name__)

    print(args)
    tp_size = args.tp_size

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        config = getattr(config, args.model_prefix)
    config = AutoConfig.from_pretrained(args.model)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] in (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "Glm4MoeForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] in ("Step3VLForConditionalGeneration"):
        E = config.text_config.moe_num_experts
        topk = config.text_config.moe_top_k
        intermediate_size = config.text_config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
        shard_intermediate_size = 2 * intermediate_size // tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096,
        ]
    else:
        batch_sizes = [args.batch_size]
        batch_sizes = args.batch_size

    ray.init(address=None, ignore_reinit_error=True, num_gpus=args.tp_size)
    use_deep_gemm = bool(args.use_deep_gemm)

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        val = os.environ["HIP_VISIBLE_DEVICES"]
        os.environ["ROCR_VISIBLE_DEVICES"] = val
        del os.environ["HIP_VISIBLE_DEVICES"]

    ray.init(address=None, ignore_reinit_error=True, num_gpus=args.num_gpus)

    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
    workers = [BenchmarkWorker.remote(args.seed, i) for i in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
...
...
@@ -322,27 +754,68 @@ def main(args: argparse.Namespace):
        return ray.get(outputs)

    if args.tune:
        search_space = get_configs_compute_bound()
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape, args.nn_moe)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
                     for batch_size in batch_sizes])
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                    args.nn_moe,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            block_quant_shape,
            use_nn_moe=args.nn_moe,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
                          for batch_size in batch_sizes])
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    block_quant_shape,
                    use_deep_gemm,
                    args.nn_moe,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
...
...
@@ -351,17 +824,23 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
    )
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--batch-size", type=int, nargs="+", required=False)
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--nn-moe", action='store_true', default=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--model-prefix", type=str, required=False)
    parser.add_argument("--num-gpus", type=int, default=1)
    args = parser.parse_args()

    main(args)
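One behavioral change above is that --batch-size now accepts multiple values (nargs="+"), and main() uses the resulting list directly. A small sketch of how that parses, using a throwaway parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
args = parser.parse_args(["--batch-size", "1", "16", "64"])

# With nargs="+", the attribute is already a list of ints,
# so the tuning loop can iterate over it without wrapping it.
print(args.batch_size)  # [1, 16, 64]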
benchmarks/kernels/benchmark_moe_align_block_size.py
0 → 100644
View file @ c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import itertools

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size_triton,
)
from vllm.triton_utils import triton


def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    return torch.stack(
        [
            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
            for _ in range(num_tokens)
        ]
    )


def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
    """
    Verifies vllm vs. Triton
    """
    topk_ids = get_topk_ids(num_tokens, num_experts, topk)

    # 1. malloc space for triton and vllm
    # malloc enough space (max_num_tokens_padded) for the sorted ids
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids_triton = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
    )
    sorted_ids_triton.fill_(topk_ids.numel())  # fill with sentinel value
    expert_ids_triton = torch.zeros(
        (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
    )
    num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")

    sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
    sorted_ids_vllm.fill_(topk_ids.numel())
    expert_ids_vllm = torch.zeros_like(expert_ids_triton)
    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)

    # 2. run implementations
    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )

    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_vllm,
        expert_ids_vllm,
        num_tokens_post_pad_vllm,
    )
    print(f"✅ VLLM implementation works with {num_experts} experts!")

    # 3. compare results
    if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
        num_tokens_post_pad_triton, num_tokens_post_pad_vllm
    ):
        print("✅ Triton and VLLM implementations match.")
    else:
        print("❌ Triton and VLLM implementations DO NOT match.")
        print("Triton expert_ids:", expert_ids_triton)
        print("VLLM expert_ids:", expert_ids_vllm)
        print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)


# test configurations
num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512]
topk_range = [1, 2, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "triton"],  # "triton"
        line_names=["VLLM", "Triton"],  # "Triton"
        plot_name="moe-align-block-size-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Benchmark function for Triton."""
    block_size = 256
    topk_ids = get_topk_ids(num_tokens, num_experts, topk)

    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: ops.moe_align_block_size(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
            ),
            quantiles=quantiles,
        )

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_experts",
        type=int,
        default=64,
        choices=[8, 16, 32, 64, 128, 256],
    )
    parser.add_argument(
        "--topk",
        type=int,
        default=8,
        choices=[2, 4, 8],
        help="Top-k value for correctness check.",
    )
    args = parser.parse_args()

    print("Running correctness check...")
    check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
    benchmark.run(print_data=True, show_plots=True)
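The key sizing rule in this new benchmark is that sorted_ids must hold every (token, expert) assignment plus worst-case per-expert padding up to a block boundary, and it is pre-filled with topk_ids.numel() as a sentinel for unused slots. A small CPU-only sketch of that sizing arithmetic, with arbitrary example numbers:

num_tokens, topk, num_experts, block_size = 1024, 8, 64, 256

# Every token contributes `topk` expert assignments...
num_assignments = num_tokens * topk  # this is what topk_ids.numel() returns
# ...and in the worst case each expert needs (block_size - 1) extra padded slots
# so that its token count becomes a multiple of block_size.
max_num_tokens_padded = num_assignments + num_experts * (block_size - 1)
max_num_m_blocks = max_num_tokens_padded // block_size

sentinel = num_assignments  # an index no real assignment can take
print(max_num_tokens_padded, max_num_m_blocks, sentinel)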
benchmarks/kernels/benchmark_moe_int4.py
0 → 100644
View file @ c2170174
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
time
from
datetime
import
datetime
from
itertools
import
product
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
TypedDict
import
ray
import
torch
import
triton
from
ray.experimental.tqdm_ray
import
tqdm
from
transformers
import
AutoConfig
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
(
)
else
torch
.
float8_e4m3fn
class
BenchmarkConfig
(
TypedDict
):
BLOCK_SIZE_M
:
int
BLOCK_SIZE_N
:
int
BLOCK_SIZE_K
:
int
GROUP_SIZE_M
:
int
num_warps
:
int
num_stages
:
int
num_ldmatrixes
:
Optional
[
int
]
def
benchmark_config
(
config
:
BenchmarkConfig
,
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
group_size
:
int
,
num_iters
:
int
=
100
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
if
use_int8_w8a16
:
if
not
nn_moe
:
w1
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
),
dtype
=
torch
.
int8
)
w2
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
),
dtype
=
torch
.
int8
)
else
:
w1
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
hidden_size
,
shard_intermediate_size
),
dtype
=
torch
.
int8
)
w2
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
),
dtype
=
torch
.
int8
)
if
use_int4_w4a16
:
        w1 = torch.randint(0, 255,
                           (num_experts, shard_intermediate_size, hidden_size // 2),
                           dtype=torch.uint8)
        w2 = torch.randint(0, 255,
                           (num_experts, hidden_size, shard_intermediate_size // 4),
                           dtype=torch.uint8)
    else:
        if not nn_moe:
            w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size,
                             dtype=init_dtype)
            w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2,
                             dtype=init_dtype)
        else:
            w1 = torch.randn(num_experts, hidden_size, shard_intermediate_size,
                             dtype=init_dtype)
            w2 = torch.randn(num_experts, shard_intermediate_size // 2, hidden_size,
                             dtype=init_dtype)
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    w1_zp = None
    w2_zp = None
    block_shape = None
    if use_int8_w8a16:
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)
        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)
    if use_int4_w4a16:
        w1_scale = torch.randn((num_experts, shard_intermediate_size,
                                hidden_size // (group_size)), dtype=torch.float16)
        w2_scale = torch.randn((num_experts, hidden_size,
                                shard_intermediate_size // (2 * group_size)),
                               dtype=torch.float16)
        w1_zp = torch.randint(0, 255,
                              (num_experts, shard_intermediate_size // 2,
                               hidden_size // (group_size)), dtype=torch.uint8)
        w2_zp = torch.randint(0, 255,
                              (num_experts, hidden_size // 2,
                               shard_intermediate_size // (2 * group_size)),
                              dtype=torch.uint8)
        nn_moe = False
        block_shape = [0, group_size]

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        from vllm.model_executor.layers.fused_moe import override_config
        with override_config(config):
            fused_moe(x,
                      w1,
                      w2,
                      input_gating,
                      topk,
                      renormalize=True,
                      inplace=True,
                      use_fp8_w8a8=use_fp8_w8a8,
                      use_int8_w8a16=use_int8_w8a16,
                      use_int4_w4a16=use_int4_w4a16,
                      w1_scale=w1_scale,
                      w2_scale=w2_scale,
                      w1_zp=w1_zp,
                      w2_zp=w2_zp,
                      a1_scale=a1_scale,
                      a2_scale=a2_scale,
                      use_nn_moe=nn_moe,
                      block_shape=block_shape)

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    # graph = torch.cuda.CUDAGraph()
    # with torch.cuda.graph(graph):
    #     for _ in range(10):
    #         run()
    # torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        # graph.replay()
        run()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: List[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        # graph.replay()
        run()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters) * 1000  # us
    # graph.reset()
    return avg


def get_rocm_tuning_space(use_fp16, use_int4_w4a16, nn_moe: Optional[bool] = False):
    if use_int4_w4a16:
        block_m_range = [16, 32, 64]
        block_n_range = [32, 64, 128]
        block_k_range = [16, 32, 64]
        num_warps_range = [1, 2, 4, 8]
        group_m_range = [1, 4, 8, 16]
        num_stage_range = [2, 4]
        num_ldmatrixes = [0]
        param_ranges = {
            "BLOCK_SIZE_M": block_m_range,
            "BLOCK_SIZE_N": block_n_range,
            "BLOCK_SIZE_K": block_k_range,
            "GROUP_SIZE_M": group_m_range,
            "num_warps": num_warps_range,
            "num_stages": num_stage_range,
            "num_ldmatrixes": num_ldmatrixes,
        }
        return param_ranges

    block_mn_range = [16, 32, 64, 128, 256]
    block_k_range = [16, 32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [1, 4, 8, 16, 32]
    num_stage_range = [2]
    waves_per_eu_range = [0]
    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
    kpack_range = [1, 2] if use_fp16 else []

    param_ranges = {
        "BLOCK_SIZE_M": block_mn_range,
        "BLOCK_SIZE_N": block_mn_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": group_m_range,
        "num_warps": num_warps_range,
        "num_stages": num_stage_range,
        "waves_per_eu": waves_per_eu_range,
    }
    if nn_moe:
        # Keep the value iterable so itertools.product() below can expand it.
        param_ranges["num_ldmatrixes"] = [1]
    if use_fp16:
        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
        param_ranges["kpack"] = kpack_range

    return param_ranges


def get_configs_compute_bound(use_fp16, use_int4_w4a16,
                              nn_moe: Optional[bool] = False) -> List[Dict[str, int]]:
    configs: List[BenchmarkConfig] = []

    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16, use_int4_w4a16, nn_moe)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        block_m_range = [16, 32, 64, 128, 256]
        block_n_range = [32, 64, 128, 256]
        block_k_range = [64, 128, 256]
        num_warps_range = [4, 8]
        group_m_range = [1, 16, 32, 64]
        num_stage_range = [2, 3, 4, 5]

        param_ranges = {
            "BLOCK_SIZE_M": block_m_range,
            "BLOCK_SIZE_N": block_n_range,
            "BLOCK_SIZE_K": block_k_range,
            "GROUP_SIZE_M": group_m_range,
            "num_warps": num_warps_range,
            "num_stages": num_stage_range,
        }

    keys, values = zip(*param_ranges.items())
    for config_values in product(*values):
        config = dict(zip(keys, config_values))
        configs.append(config)
    return configs


def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
                            search_space, is_fp16):
    N1, K1 = shard_intermediate_size, hidden_size
    N2, K2 = hidden_size, shard_intermediate_size // 2
    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space, is_fp16)
    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space, is_fp16)
    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
    return search_space


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    pruned_configs = []
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = False
    if M >= 2048 and N >= 2048:
        large_gemm = True

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        if is_fp16:
            if (matrix_instr_nonkdim > BLOCK_SIZE_M
                    or matrix_instr_nonkdim > BLOCK_SIZE_N):
                continue
            if (matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M):
                continue
            if (matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N):
                continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def merge_unique_dicts(list1, list2):
    result = []
    combined_list = list1.copy()
    combined_list.extend(list2)
    for dictionary in combined_list:
        if dictionary not in result:
            result.append(dictionary)
    return result


@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_int4_w4a16: bool,
        group_size: int,
    ) -> Tuple[Dict[str, int], float]:
        current_platform.seed_everything(self.seed)
        dtype_str = get_config_dtype_str(dtype,
                                         use_int4_w4a16=use_int4_w4a16,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        config_shard_intermediate_size = shard_intermediate_size
        if use_int4_w4a16:
            config_shard_intermediate_size = shard_intermediate_size // 2
        op_config = get_moe_configs(num_experts, config_shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens, num_experts,
                                        config_shard_intermediate_size, hidden_size,
                                        topk, dtype_str, is_marlin=False)
        else:
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size, topk,
                                       dtype, use_fp8_w8a8, use_int8_w8a16,
                                       use_int4_w4a16, group_size)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_int4_w4a16: bool,
        group_size: int,
        search_space: List[Dict[str, int]],
        nn_moe: Optional[bool] = False,
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
            search_space = prune_rocm_search_space(num_tokens,
                                                   shard_intermediate_size,
                                                   hidden_size, search_space, is_fp16)

        with torch.cuda.device(self.device_id):
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(config,
                                                   num_tokens,
                                                   num_experts,
                                                   shard_intermediate_size,
                                                   hidden_size,
                                                   topk,
                                                   dtype,
                                                   use_fp8_w8a8,
                                                   use_int8_w8a16,
                                                   use_int4_w4a16,
                                                   group_size,
                                                   num_iters=20,
                                                   nn_moe=nn_moe)
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config


def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    if "num_ldmatrixes" not in config:
        return {
            "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
            "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
            "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
            "GROUP_SIZE_M": config["GROUP_SIZE_M"],
            "num_warps": config["num_warps"],
            "num_stages": config["num_stages"],
            **({"waves_per_eu": config["waves_per_eu"]}
               if "waves_per_eu" in config else {}),
            **({"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
               if "matrix_instr_nonkdim" in config else {}),
            **({"kpack": config["kpack"]} if "kpack" in config else {}),
        }
    else:
        return {
            "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
            "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
            "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
            "GROUP_SIZE_M": config["GROUP_SIZE_M"],
            "num_warps": config["num_warps"],
            "num_stages": config["num_stages"],
            "num_ldmatrixes": config["num_ldmatrixes"],
            **({"waves_per_eu": config["waves_per_eu"]}
               if "waves_per_eu" in config else {}),
            **({"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
               if "matrix_instr_nonkdim" in config else {}),
            **({"kpack": config["kpack"]} if "kpack" in config else {}),
        }


def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
                 shard_intermediate_size: int, hidden_size: int, topk: int,
                 dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
                 use_int4_w4a16: bool,
                 use_nn_moe: Optional[bool] = False) -> None:
    dtype_str = get_config_dtype_str(dtype,
                                     use_int8_w8a16=use_int8_w8a16,
                                     use_int4_w4a16=use_int4_w4a16,
                                     use_fp8_w8a8=use_fp8_w8a8)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str, use_nn_moe=use_nn_moe)

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(args.model,
                                        trust_remote_code=args.trust_remote_code)
    group_size = None
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
        if config.quantization_config['quant_method'] == "awq":
            group_size = config.quantization_config["group_size"]
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_int4_w4a16 = args.dtype == "int4_w4a16"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072,
            4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init(address=None, ignore_reinit_error=True, num_gpus=1)
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
        search_space = get_configs_compute_bound(is_fp16, use_int4_w4a16, args.nn_moe)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
                      use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, group_size,
                      search_space, args.nn_moe) for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        if use_int4_w4a16:
            save_configs(best_configs, E, shard_intermediate_size // 2, hidden_size,
                         topk, dtype, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
                         use_nn_moe=args.nn_moe)
        else:
            save_configs(best_configs, E, shard_intermediate_size, hidden_size, topk,
                         dtype, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
                         use_nn_moe=args.nn_moe)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size, topk,
                           dtype, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
                           group_size) for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--model", type=str, default="")
    parser.add_argument("--tp-size", "-tp", "--tensor-parallel-size",
                        type=int, default=8)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"],
                        default="int4_w4a16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true", default=False)
    parser.add_argument("--nn_moe", type=bool, default=False)
    parser.add_argument("--trust-remote-code", action="store_true", default=True)
    args = parser.parse_args()

    main(args)
\ No newline at end of file
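For context, a minimal sketch (not part of this commit) of how the tuned-configuration JSON written by save_configs() above could be read back and matched to a batch size. It mirrors the nearest-batch-size lookup in BenchmarkWorker.benchmark(); the filename in the usage comment is hypothetical, the real one comes from get_config_file_name().

import json

def load_best_config(path: str, num_tokens: int) -> dict:
    # JSON stringifies the integer batch-size keys, so convert them back.
    with open(path) as f:
        configs = {int(m): cfg for m, cfg in json.load(f).items()}
    # Same rule as op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))].
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# Hypothetical usage:
# cfg = load_best_config("E=8,N=1024,dtype=int4_w4a16.json", num_tokens=96)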
benchmarks/kernels/benchmark_moe_permute_unpermute.py
0 → 100644
View file @
c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
from typing import Any, TypedDict

import ray
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _moe_permute,
    _moe_unpermute_and_reduce,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser

FP8_DTYPE = current_platform.fp8_dtype()


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def benchmark_permute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
                moe_permute(
                    qhidden_states,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    token_expert_indices=token_expert_indices,
                    topk=topk,
                    n_expert=num_experts,
                    n_local_expert=num_experts,
                    expert_map=None,
                    align_block_size=align_block_size,
                )
            )
        else:
            (
                permuted_hidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = _moe_permute(
                qhidden_states, None, topk_ids, num_experts, None, align_block_size
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


def benchmark_unpermute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare():
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
                moe_permute(
                    qhidden_states,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    token_expert_indices=token_expert_indices,
                    topk=topk,
                    n_expert=num_experts,
                    n_local_expert=num_experts,
                    expert_map=None,
                    align_block_size=align_block_size,
                )
            )
            # convert to fp16/bf16 as gemm output
            return (
                permuted_hidden_states.to(dtype),
                first_token_off,
                inv_perm_idx,
                m_indices,
            )
        else:
            (
                permuted_qhidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = _moe_permute(
                qhidden_states, None, topk_ids, num_experts, None, align_block_size
            )
            # convert to fp16/bf16 as gemm output
            return (
                permuted_qhidden_states.to(dtype),
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            )

    def run(input: tuple):
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
            moe_unpermute(
                permuted_hidden_states,
                topk_weights,
                topk_ids,
                inv_perm_idx,
                first_token_off,
                topk,
                num_experts,
                num_experts,
            )
        else:
            (
                permuted_hidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = input
            _moe_unpermute_and_reduce(
                output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
            )

    # JIT compilation & warmup
    input = prepare()
    run(input)
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run(input)
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        torch.cuda.synchronize()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_customized_permute: bool = False,
    ) -> tuple[dict[str, int], float]:
        current_platform.seed_everything(self.seed)
        permute_time = benchmark_permute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        unpermute_time = benchmark_unpermute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        return permute_time, unpermute_time


def get_weight_block_size_safety(config, default_value=None):
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        return quantization_config.get("weight_block_size", default_value)
    return default_value


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(args.model,
                                        trust_remote_code=args.trust_remote_code)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
    elif (
        config.architectures[0] == "DeepseekV3ForCausalLM"
        or config.architectures[0] == "DeepseekV2ForCausalLM"
        or config.architectures[0] == "Glm4MoeForCausalLM"
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
        E = config.num_experts
        topk = config.num_experts_per_tok
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_customized_permute = args.use_customized_permute

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072,
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    outputs = _distribute(
        "benchmark",
        [
            (
                batch_size,
                E,
                hidden_size,
                topk,
                dtype,
                use_fp8_w8a8,
                use_int8_w8a16,
                use_customized_permute,
            )
            for batch_size in batch_sizes
        ],
    )

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--model", type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--dtype", type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto")
    parser.add_argument("--use-customized-permute", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    main(args)
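For reference, a minimal sketch (not part of this commit) of the CUDA-graph timing pattern shared by benchmark_permute() and benchmark_unpermute() above: ten calls are captured into one graph, each replay is timed with CUDA events (which report milliseconds), and the result is divided by the ten captured calls and scaled to microseconds.

import torch

def time_graph(fn, num_iters: int = 100) -> float:
    fn()                      # JIT compilation & warmup
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):   # each replay re-runs these 10 calls
            fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    latencies = []
    for _ in range(num_iters):
        torch.cuda.synchronize()
        start.record()
        graph.replay()
        end.record()
        end.synchronize()
        latencies.append(start.elapsed_time(end))  # milliseconds per replay
    graph.reset()
    # Each replay covers 10 captured calls; elapsed_time() is in ms, so *1000 -> us.
    return sum(latencies) / (num_iters * 10) * 1000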
benchmarks/kernels/benchmark_paged_attention.py
View file @
c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time
from typing import List, Optional
from typing import Optional

import torch

from vllm import _custom_ops as ops
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                        create_kv_caches_with_random, seed_everything)
from vllm.logger import init_logger
from vllm.platforms import current_platform
import vllm.envs as envs

NUM_BLOCKS = 1024
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random,
)

logger = init_logger(__name__)

NUM_BLOCKS = 128 * 1024
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256


@torch.inference_mode()
...
@@ -29,22 +41,18 @@ def main(
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    seed_everything(seed)
    current_platform.seed_everything(seed)
    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, device=device)
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, device=device)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)

    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
...
@@ -52,33 +60,38 @@ def main(
    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    block_tables_lst: list[list[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)

    # Create the KV cache.
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, dtype,
        device=device)
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        if current_platform.is_rocm():
            global PARTITION_SIZE
            if not args.custom_paged_attn and not current_platform.is_navi():
                PARTITION_SIZE = 1024
            else:
                PARTITION_SIZE = PARTITION_SIZE_ROCM
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
...
@@ -90,6 +103,10 @@ def main(
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    if version == "v12":
        sliding_window = ((-1, -1))
        logits_soft_cap = 0.0

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
...
@@ -98,12 +115,12 @@ def main(
        start_time = time.perf_counter()

        # Using default kv_scale
        k_scale = v_scale = 1.0
        k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

        for _ in range(num_iters):
            if version == "v1":
                if envs.VLLM_USE_OPT_OP:
                    if envs.VLLM_USE_TC_PAGED_ATTN:
                if args.gc_paged_attn:
                    if args.tc_paged_attn:
                        ops.paged_attention_v1_opt_tc(
                            output,
                            query,
...
@@ -155,74 +172,110 @@ def main(
                            v_scale,
                        )
            elif version == "v2":
                if envs.VLLM_USE_OPT_OP:
                    if envs.VLLM_USE_TC_PAGED_ATTN:
                        ops.paged_attention_v2_opt_tc(
                            output, exp_sums, max_logits, tmp_output, query, key_cache,
                            value_cache, num_kv_heads, scale, block_tables, seq_lens,
                            block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                            k_scale, v_scale,
                        )
                    else:
                        ops.paged_attention_v2_opt(
                            output, exp_sums, max_logits, tmp_output, query, key_cache,
                            value_cache, num_kv_heads, scale, block_tables, seq_lens,
                            block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                            k_scale, v_scale,
                        )
                else:
                    if not args.custom_paged_attn:
                if args.gc_paged_attn:
                    if args.tc_paged_attn:
                        ops.paged_attention_v1_opt_tc(
                            output, query, key_cache, value_cache, num_kv_heads, scale,
                            block_tables, seq_lens, block_size, max_seq_len,
                            alibi_slopes, kv_cache_dtype, k_scale, v_scale,
                        )
                    else:
                        ops.paged_attention_v2_opt(
                            output, exp_sums, max_logits, tmp_output, query, key_cache,
                            value_cache, num_kv_heads, scale, block_tables, seq_lens,
                            block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                            k_scale, v_scale,
                        )
                        ops.paged_attention_v2(
                            output, exp_sums, max_logits, tmp_output, query, key_cache,
                            value_cache, num_kv_heads, scale, block_tables, seq_lens,
                            block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                            k_scale, v_scale,
                        )
                            output, exp_sums, max_logits, tmp_output, query, key_cache,
                            value_cache, num_kv_heads, scale, block_tables, seq_lens,
                            block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                            k_scale, v_scale,
                        )
                    else:
                        ops.paged_attention_rocm(
                            output, exp_sums, max_logits, tmp_output, query, key_cache,
                            value_cache, num_kv_heads, scale, block_tables, seq_lens,
                            None, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                            k_scale, v_scale,
                        )
            elif version == "v12":
                from flash_attn import vllm_flash_attn_with_kvcache

                vllm_flash_attn_with_kvcache(
                    q=query.unsqueeze(1),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    cache_seqlens=seq_lens,
                    block_table=block_tables,
                    softmax_scale=scale,
                    causal=True,
                    window_size=sliding_window,
                    softcap=logits_soft_cap,
                    alibi_slopes=alibi_slopes,
                    return_softmax_lse=False,
                    k_scale=k_scale,
                    v_scale=v_scale,
                    kv_cache_dtype=kv_cache_dtype,
                ).squeeze(1)
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
...
@@ -238,27 +291,29 @@ def main(
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
    parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
if __name__ == "__main__":
    logger.warning(
        "This script benchmarks the paged attention kernel. "
        "By default this is no longer used in vLLM inference."
    )

    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
    parser.add_argument("--version", type=str, choices=["v1", "v2", "v12"], default="v12")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size", type=int,
                        choices=[64, 80, 96, 112, 120, 128, 192, 256], default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--head-size", type=int,
                        choices=[64, 80, 96, 112, 120, 128, 192, 256], default=128,)
    parser.add_argument("--block-size", type=int, choices=[16, 32, 64], default=64)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype", type=str, choices=["half", "bfloat16", "float"], default="half")
    parser.add_argument("--dtype", type=str, choices=["half", "bfloat16", "float"], default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
...
@@ -269,6 +324,15 @@ if __name__ == '__main__':
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (hcu) supports fp8 (=fp8_e4m3)")
    parser.add_argument("--gc-paged-attn", action="store_true",
                        help="Use gc paged attention")
    parser.add_argument("--tc-paged-attn", action="store_true",
                        help="Use tc paged attention")
    parser.add_argument("--custom-paged-attn", action="store_true",
                        help="Use custom paged attention")
    args = parser.parse_args()
    print(args)
...
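For comparison, a minimal sketch (not part of this commit) of the wall-clock timing pattern that run_cuda_benchmark() in this diff, and benchmark_quant.py below, use instead of CUDA graphs: synchronize, launch num_iters kernels, synchronize again, and report the mean seconds per iteration. The kernel argument is a stand-in for any of the paged-attention or quantization calls above.

import time
import torch

def run_benchmark(kernel, num_iters: int) -> float:
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(num_iters):
        kernel()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return (end_time - start_time) / num_iters  # seconds per iteration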
benchmarks/kernels/benchmark_quant.py
View file @
c2170174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                        seed_everything)
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         static_scale: bool,
         quant_dtype: torch.dtype,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    seed_everything(seed)
def main(
    num_tokens: int,
    hidden_size: int,
    static_scale: bool,
    quant_dtype: torch.dtype,
    dtype: torch.dtype,
    seed: int = 0,
    do_profile: bool = False,
    num_warmup_iters: int = 5,
    num_iters: int = 100,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
...
@@ -38,7 +43,7 @@ def main(num_tokens: int,
        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
...
@@ -54,7 +59,7 @@ def main(num_tokens: int,
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "int8":
...
@@ -64,37 +69,40 @@ if __name__ == '__main__':
            raise ValueError(f"Unsupported dtype: {dt}")

    parser = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel.")
        description="Benchmark the quantization (fp8 or int8) kernel."
    )
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--static-scale", action="store_true")
    parser.add_argument("--quant-dtype", type=str, choices=["fp8", "int8"], default="int8")
    parser.add_argument("--dtype", type=str, choices=["half", "bfloat16", "float"], default="half")
    parser.add_argument("--quant-dtype", type=str, choices=["fp8", "int8"], default="int8")
    parser.add_argument("--dtype", type=str, choices=["half", "bfloat16", "float"], default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument("--num-iters", type=int, default=100,
                        help="Number of benchmark iterations. "
                        "If --profile is set, this number is ignored")
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations. "
        "If --profile is set, this number is ignored",
    )

    args = parser.parse_args()
    print(args)

    main(num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         static_scale=args.static_scale,
         quant_dtype=to_torch_dtype(args.quant_dtype),
         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters)
    main(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        static_scale=args.static_scale,
        quant_dtype=to_torch_dtype(args.quant_dtype),
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
    )