Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c676e3d
Commit
4c676e3d
authored
Jun 20, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.1' into v0.9.1-dev
parents
b4c4464d
b6553be1
Changes
418
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1433 additions
and
609 deletions
+1433
-609
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+118
-73
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+292
-201
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+417
-0
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+51
-54
benchmarks/kernels/benchmark_quant.py
benchmarks/kernels/benchmark_quant.py
+39
-33
benchmarks/kernels/benchmark_rmsnorm.py
benchmarks/kernels/benchmark_rmsnorm.py
+26
-34
benchmarks/kernels/benchmark_rope.py
benchmarks/kernels/benchmark_rope.py
+48
-38
benchmarks/kernels/benchmark_shapes.py
benchmarks/kernels/benchmark_shapes.py
+1
-0
benchmarks/kernels/benchmark_w8a8_block_fp8.py
benchmarks/kernels/benchmark_w8a8_block_fp8.py
+54
-60
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
...hmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+5
-2
benchmarks/kernels/graph_machete_bench.py
benchmarks/kernels/graph_machete_bench.py
+18
-18
benchmarks/kernels/utils.py
benchmarks/kernels/utils.py
+27
-26
benchmarks/kernels/weight_shapes.py
benchmarks/kernels/weight_shapes.py
+47
-0
benchmarks/overheads/benchmark_hashing.py
benchmarks/overheads/benchmark_hashing.py
+20
-17
benchmarks/pyproject.toml
benchmarks/pyproject.toml
+49
-0
benchmarks/run_structured_output_benchmark.sh
benchmarks/run_structured_output_benchmark.sh
+87
-23
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+45
-7
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+18
-2
cmake/hipify.py
cmake/hipify.py
+1
-0
cmake/utils.cmake
cmake/utils.cmake
+70
-21
No files found.
Too many changes to show.
To preserve performance only
418 of 418+
files are displayed.
Plain diff
Email patch
benchmarks/kernels/benchmark_marlin.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch
import
torch.utils.benchmark
as
benchmark
import
torch.utils.benchmark
as
benchmark
...
@@ -6,19 +7,34 @@ from benchmark_shapes import WEIGHT_SHAPES
...
@@ -6,19 +7,34 @@ from benchmark_shapes import WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
)
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
,
)
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
ALLSPARK_SUPPORTED_QUANT_TYPES
)
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
ALLSPARK_SUPPORTED_QUANT_TYPES
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
,
MARLIN_SUPPORTED_GROUP_SIZES
,
query_marlin_supported_quant_types
)
GPTQ_MARLIN_MIN_THREAD_N
,
MARLIN_SUPPORTED_GROUP_SIZES
,
query_marlin_supported_quant_types
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
marlin_quantize
)
MarlinWorkspace
,
marlin_quantize
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
marlin_24_quantize
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
gptq_quantize_weights
,
quantize_weights
,
sort_weights
)
gptq_pack
,
gptq_quantize_weights
,
quantize_weights
,
sort_weights
,
)
from
vllm.scalar_type
import
ScalarType
from
vllm.scalar_type
import
ScalarType
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
...
@@ -29,22 +45,29 @@ ACT_ORDER_OPTS = [False, True]
...
@@ -29,22 +45,29 @@ ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
def
bench_run
(
results
:
list
[
benchmark
.
Measurement
],
model
:
str
,
def
bench_run
(
act_order
:
bool
,
is_k_full
:
bool
,
quant_type
:
ScalarType
,
results
:
list
[
benchmark
.
Measurement
],
group_size
:
int
,
size_m
:
int
,
size_k
:
int
,
size_n
:
int
):
model
:
str
,
act_order
:
bool
,
is_k_full
:
bool
,
quant_type
:
ScalarType
,
group_size
:
int
,
size_m
:
int
,
size_k
:
int
,
size_n
:
int
,
):
label
=
"Quant Matmul"
label
=
"Quant Matmul"
sub_label
=
(
"{}, act={} k_full={}, q={}, g={}, "
sub_label
=
"{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})"
.
format
(
"MKN=({}x{}x{})"
.
format
(
model
,
act_order
,
is_k_full
,
model
,
act_order
,
is_k_full
,
str
(
quant_type
),
group_size
,
size_m
,
size_k
,
size_n
str
(
quant_type
),
group_size
,
size_m
,
)
size_k
,
size_n
))
print
(
f
"Testing:
{
sub_label
}
"
)
print
(
f
"Testing:
{
sub_label
}
"
)
a
=
torch
.
randn
(
size_m
,
size_k
).
to
(
torch
.
half
).
cuda
()
a
=
torch
.
randn
(
size_m
,
size_k
).
to
(
torch
.
half
).
cuda
()
b
=
torch
.
rand
(
size_k
,
size_n
).
to
(
torch
.
half
).
cuda
()
b
=
torch
.
rand
(
size_k
,
size_n
).
to
(
torch
.
half
).
cuda
()
a_tmp
=
(
torch
.
zeros
(
size_m
,
size_k
).
to
(
torch
.
half
).
cuda
()
)
a_tmp
=
torch
.
zeros
(
size_m
,
size_k
).
to
(
torch
.
half
).
cuda
()
# Marlin quant
# Marlin quant
(
(
...
@@ -57,14 +80,16 @@ def bench_run(results: list[benchmark.Measurement], model: str,
...
@@ -57,14 +80,16 @@ def bench_run(results: list[benchmark.Measurement], model: str,
)
=
marlin_quantize
(
b
,
quant_type
,
group_size
,
act_order
)
)
=
marlin_quantize
(
b
,
quant_type
,
group_size
,
act_order
)
# Marlin_24 quant
# Marlin_24 quant
(
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
(
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
=
(
marlin_24_s
)
=
marlin_24_quantize
(
b
,
quant_type
,
group_size
)
marlin_24_quantize
(
b
,
quant_type
,
group_size
)
)
marlin_zp
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
b
.
device
)
marlin_zp
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
b
.
device
)
# GPTQ quant
# GPTQ quant
(
w_ref
,
q_w
,
s
,
g_idx
,
(
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
)
=
gptq_quantize_weights
(
rand_perm
)
=
gptq_quantize_weights
(
b
,
quant_type
,
group_size
,
act_order
)
b
,
quant_type
,
group_size
,
act_order
)
q_w_gptq
=
gptq_pack
(
q_w
,
quant_type
.
size_bits
,
size_k
,
size_n
)
q_w_gptq
=
gptq_pack
(
q_w
,
quant_type
.
size_bits
,
size_k
,
size_n
)
# For act_order, sort the "weights" and "g_idx"
# For act_order, sort the "weights" and "g_idx"
...
@@ -74,32 +99,37 @@ def bench_run(results: list[benchmark.Measurement], model: str,
...
@@ -74,32 +99,37 @@ def bench_run(results: list[benchmark.Measurement], model: str,
(
q_w
,
g_idx
,
repack_sort_indices
)
=
sort_weights
(
q_w
,
g_idx
)
(
q_w
,
g_idx
,
repack_sort_indices
)
=
sort_weights
(
q_w
,
g_idx
)
# Prepare
# Prepare
marlin_workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
marlin_workspace
=
MarlinWorkspace
(
GPTQ_MARLIN_MAX_PARALLEL
)
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
marlin_24_workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
marlin_24_workspace
=
MarlinWorkspace
(
GPTQ_MARLIN_24_MAX_PARALLEL
)
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
)
marlin_zp
=
torch
.
zeros_like
(
marlin_s
,
dtype
=
torch
.
int
)
marlin_zp
=
torch
.
zeros_like
(
marlin_s
,
dtype
=
torch
.
int
)
# AllSpark W8A16 quant
# AllSpark W8A16 quant
as_supported_case
=
(
quant_type
in
ALLSPARK_SUPPORTED_QUANT_TYPES
as_supported_case
=
(
and
group_size
==
-
1
and
not
act_order
and
is_k_full
)
quant_type
in
ALLSPARK_SUPPORTED_QUANT_TYPES
and
group_size
==
-
1
and
not
act_order
and
is_k_full
)
if
as_supported_case
:
if
as_supported_case
:
properties
=
torch
.
cuda
.
get_device_properties
(
b
.
device
.
index
)
properties
=
torch
.
cuda
.
get_device_properties
(
b
.
device
.
index
)
sm_count
=
properties
.
multi_processor_count
sm_count
=
properties
.
multi_processor_count
sm_version
=
properties
.
major
*
10
+
properties
.
minor
sm_version
=
properties
.
major
*
10
+
properties
.
minor
supported_arch
=
(
sm_version
>=
80
and
sm_version
<
90
)
supported_arch
=
sm_version
>=
80
and
sm_version
<
90
as_supported_case
=
as_supported_case
and
supported_arch
as_supported_case
=
as_supported_case
and
supported_arch
if
supported_arch
:
if
supported_arch
:
has_zp
=
False
has_zp
=
False
w_ref
,
qw
,
s
,
zp
=
quantize_weights
(
b
,
quant_type
,
group_size
,
w_ref
,
qw
,
s
,
zp
=
quantize_weights
(
b
,
quant_type
,
group_size
,
has_zp
)
has_zp
)
qw
=
qw
.
to
(
torch
.
uint8
)
qw
=
qw
.
to
(
torch
.
uint8
)
qw_reorder
,
s_reorder
,
zp_reorder
=
\
qw_reorder
,
s_reorder
,
zp_reorder
=
ops
.
allspark_repack_weight
(
ops
.
allspark_repack_weight
(
qw
,
s
,
zp
,
has_zp
qw
,
s
,
zp
,
has_zp
)
)
CUBLAS_M_THRESHOLD
=
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
CUBLAS_M_THRESHOLD
=
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
globals
=
{
globals
=
{
...
@@ -136,8 +166,7 @@ def bench_run(results: list[benchmark.Measurement], model: str,
...
@@ -136,8 +166,7 @@ def bench_run(results: list[benchmark.Measurement], model: str,
"zp_reorder"
:
zp_reorder
if
as_supported_case
else
None
,
"zp_reorder"
:
zp_reorder
if
as_supported_case
else
None
,
"sm_count"
:
sm_count
if
as_supported_case
else
None
,
"sm_count"
:
sm_count
if
as_supported_case
else
None
,
"sm_version"
:
sm_version
if
as_supported_case
else
None
,
"sm_version"
:
sm_version
if
as_supported_case
else
None
,
"CUBLAS_M_THRESHOLD"
:
"CUBLAS_M_THRESHOLD"
:
CUBLAS_M_THRESHOLD
if
as_supported_case
else
None
,
CUBLAS_M_THRESHOLD
if
as_supported_case
else
None
,
# Kernels
# Kernels
"gptq_marlin_gemm"
:
ops
.
gptq_marlin_gemm
,
"gptq_marlin_gemm"
:
ops
.
gptq_marlin_gemm
,
"gptq_marlin_24_gemm"
:
ops
.
gptq_marlin_24_gemm
,
"gptq_marlin_24_gemm"
:
ops
.
gptq_marlin_24_gemm
,
...
@@ -158,60 +187,63 @@ def bench_run(results: list[benchmark.Measurement], model: str,
...
@@ -158,60 +187,63 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
description
=
"pytorch_gemm"
,
description
=
"pytorch_gemm"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
stmt
=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)"
,
# noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
description
=
"gptq_marlin_gemm_fp16"
,
description
=
"gptq_marlin_gemm_fp16"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
stmt
=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)"
,
# noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
description
=
"gptq_marlin_gemm_fp32"
,
description
=
"gptq_marlin_gemm_fp32"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
if
(
quant_type
in
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
if
(
and
group_size
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
):
quant_type
in
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and
group_size
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
):
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
stmt
=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)"
,
# noqa: E501
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
description
=
"gptq_marlin_24_gemm"
,
description
=
"gptq_marlin_24_gemm"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
stmt
=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)"
,
# noqa: E501
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
description
=
"gptq_marlin_repack"
,
description
=
"gptq_marlin_repack"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
if
as_supported_case
:
if
as_supported_case
:
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
stmt
=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)"
,
# noqa: E501
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
description
=
"allspark_w8a16_gemm_fp32"
,
description
=
"allspark_w8a16_gemm_fp32"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
def
main
(
args
):
def
main
(
args
):
...
@@ -233,37 +265,50 @@ def main(args):
...
@@ -233,37 +265,50 @@ def main(args):
continue
continue
for
act_order
in
ACT_ORDER_OPTS
:
for
act_order
in
ACT_ORDER_OPTS
:
if
len
(
args
.
limit_act_order
if
(
)
>
0
and
act_order
not
in
args
.
limit_act_order
:
len
(
args
.
limit_act_order
)
>
0
and
act_order
not
in
args
.
limit_act_order
):
continue
continue
for
is_k_full
in
K_FULL_OPTS
:
for
is_k_full
in
K_FULL_OPTS
:
if
len
(
args
.
limit_k_full
if
(
)
>
0
and
is_k_full
not
in
args
.
limit_k_full
:
len
(
args
.
limit_k_full
)
>
0
and
is_k_full
not
in
args
.
limit_k_full
):
continue
continue
for
quant_type
in
query_marlin_supported_quant_types
(
for
quant_type
in
query_marlin_supported_quant_types
(
False
):
False
):
if
(
if
len
(
args
.
limit_num_bits
)
>
0
and
\
len
(
args
.
limit_num_bits
)
>
0
quant_type
.
size_bits
not
in
args
.
limit_num_bits
:
and
quant_type
.
size_bits
not
in
args
.
limit_num_bits
):
continue
continue
for
group_size
in
MARLIN_SUPPORTED_GROUP_SIZES
:
for
group_size
in
MARLIN_SUPPORTED_GROUP_SIZES
:
if
len
(
if
(
args
.
limit_group_size
len
(
args
.
limit_group_size
)
>
0
)
>
0
and
group_size
not
in
args
.
limit_group_size
:
and
group_size
not
in
args
.
limit_group_size
):
continue
continue
# For act_order, the group_size must be less than
# For act_order, the group_size must be less than
# size_k
# size_k
if
act_order
and
(
group_size
==
size_k
if
act_order
and
(
group_size
==
size_k
or
group_size
==
-
1
):
or
group_size
==
-
1
):
continue
continue
for
size_m
in
args
.
batch_sizes
:
for
size_m
in
args
.
batch_sizes
:
bench_run
(
results
,
model
,
act_order
,
is_k_full
,
bench_run
(
quant_type
,
group_size
,
size_m
,
results
,
size_k
,
size_n
)
model
,
act_order
,
is_k_full
,
quant_type
,
group_size
,
size_m
,
size_k
,
size_n
,
)
compare
=
benchmark
.
Compare
(
results
)
compare
=
benchmark
.
Compare
(
results
)
compare
.
print
()
compare
.
print
()
...
@@ -274,7 +319,8 @@ def main(args):
...
@@ -274,7 +319,8 @@ def main(args):
#
#
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark Marlin across specified models/shapes/batches"
)
description
=
"Benchmark Marlin across specified models/shapes/batches"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--models"
,
"--models"
,
nargs
=
"+"
,
nargs
=
"+"
,
...
@@ -282,10 +328,9 @@ if __name__ == "__main__":
...
@@ -282,10 +328,9 @@ if __name__ == "__main__":
default
=
DEFAULT_MODELS
,
default
=
DEFAULT_MODELS
,
choices
=
WEIGHT_SHAPES
.
keys
(),
choices
=
WEIGHT_SHAPES
.
keys
(),
)
)
parser
.
add_argument
(
"--batch-sizes"
,
parser
.
add_argument
(
nargs
=
"+"
,
"--batch-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_BATCH_SIZES
type
=
int
,
)
default
=
DEFAULT_BATCH_SIZES
)
parser
.
add_argument
(
"--limit-k"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-k"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-n"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-n"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-group-size"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-group-size"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
...
...
benchmarks/kernels/benchmark_moe.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
argparse
import
json
import
json
...
@@ -10,12 +11,12 @@ from typing import Any, TypedDict
...
@@ -10,12 +11,12 @@ from typing import Any, TypedDict
import
ray
import
ray
import
torch
import
torch
import
triton
from
ray.experimental.tqdm_ray
import
tqdm
from
ray.experimental.tqdm_ray
import
tqdm
from
transformers
import
AutoConfig
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -31,83 +32,91 @@ class BenchmarkConfig(TypedDict):
...
@@ -31,83 +32,91 @@ class BenchmarkConfig(TypedDict):
num_ldmatrixes
:
Optional
[
int
]
num_ldmatrixes
:
Optional
[
int
]
def
benchmark_config
(
config
:
BenchmarkConfig
,
def
benchmark_config
(
num_tokens
:
int
,
config
:
BenchmarkConfig
,
num_experts
:
int
,
num_tokens
:
int
,
shard_intermediate_size
:
int
,
num_experts
:
int
,
hidden_size
:
int
,
shard_intermediate_size
:
int
,
topk
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
topk
:
int
,
use_fp8_w8a8
:
bool
,
dtype
:
torch
.
dtype
,
use_int8_w8a16
:
bool
,
use_fp8_w8a8
:
bool
,
num_iters
:
int
=
100
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
List
[
int
]
=
None
,
num_iters
:
int
=
100
,
use_deep_gemm
:
bool
=
False
,
block_quant_shape
:
list
[
int
]
=
None
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
float
:
use_deep_gemm
:
bool
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
if
use_int8_w8a16
:
if
use_int8_w8a16
:
if
not
nn_moe
:
if
not
nn_moe
:
w1
=
torch
.
randint
(
-
127
,
w1
=
torch
.
randint
(
127
,
(
-
127
,
num_experts
,
127
,
shard_intermediate_size
,
(
hidden_size
,
num_experts
,
),
shard_intermediate_size
,
dtype
=
torch
.
int8
)
hidden_size
,
w2
=
torch
.
randint
(
-
127
,
),
127
,
(
dtype
=
torch
.
int8
,
num_experts
,
)
hidden_size
,
w2
=
torch
.
randint
(
shard_intermediate_size
//
2
,
-
127
,
),
127
,
dtype
=
torch
.
int8
)
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
),
dtype
=
torch
.
int8
,
)
else
:
else
:
w1
=
torch
.
randint
(
-
127
,
w1
=
torch
.
randint
(
127
,
(
-
127
,
num_experts
,
127
,
hidden_size
,
(
shard_intermediate_size
num_experts
,
),
hidden_size
,
dtype
=
torch
.
int8
)
shard_intermediate_size
,
w2
=
torch
.
randint
(
-
127
,
),
127
,
(
dtype
=
torch
.
int8
,
num_experts
,
)
shard_intermediate_size
//
2
,
w2
=
torch
.
randint
(
hidden_size
-
127
,
),
127
,
dtype
=
torch
.
int8
)
(
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
,
),
dtype
=
torch
.
int8
,
)
else
:
else
:
if
not
nn_moe
:
if
not
nn_moe
:
w1
=
torch
.
randn
(
num_experts
,
w1
=
torch
.
randn
(
shard_intermediate_size
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
dtype
=
init_dtype
hidden_size
,
)
dtype
=
init_dtype
)
w2
=
torch
.
randn
(
w2
=
torch
.
randn
(
num_experts
,
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
dtype
=
init_dtype
hidden_size
,
)
shard_intermediate_size
//
2
,
dtype
=
init_dtype
)
else
:
else
:
w1
=
torch
.
randn
(
num_experts
,
w1
=
torch
.
randn
(
hidden_size
,
num_experts
,
hidden_size
,
shard_intermediate_size
,
dtype
=
init_dtype
shard_intermediate_size
,
)
dtype
=
init_dtype
)
w2
=
torch
.
randn
(
w2
=
torch
.
randn
(
num_experts
,
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
,
dtype
=
init_dtype
shard_intermediate_size
//
2
,
)
hidden_size
,
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
dtype
=
init_dtype
)
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
w1_scale
=
None
w1_scale
=
None
w2_scale
=
None
w2_scale
=
None
a1_scale
=
None
a1_scale
=
None
a2_scale
=
None
a2_scale
=
None
if
use_int8_w8a16
:
if
use_int8_w8a16
:
w1_scale
=
torch
.
randn
((
num_experts
,
2
*
shard_intermediate_size
),
w1_scale
=
torch
.
randn
(
dtype
=
torch
.
float32
)
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
if
block_quant_shape
:
if
block_quant_shape
:
...
@@ -120,10 +129,14 @@ def benchmark_config(config: BenchmarkConfig,
...
@@ -120,10 +129,14 @@ def benchmark_config(config: BenchmarkConfig,
n_tiles_w2
=
(
K
+
block_n
-
1
)
//
block_n
n_tiles_w2
=
(
K
+
block_n
-
1
)
//
block_n
k_tiles_w1
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles_w1
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles_w2
=
(
N
+
block_k
-
1
)
//
block_k
k_tiles_w2
=
(
N
+
block_k
-
1
)
//
block_k
w1_scale
=
torch
.
rand
((
E
,
n_tiles_w1
,
k_tiles_w1
),
w1_scale
=
(
dtype
=
torch
.
float32
)
*
factor_for_scale
torch
.
rand
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
rand
((
E
,
n_tiles_w2
,
k_tiles_w2
),
*
factor_for_scale
dtype
=
torch
.
float32
)
*
factor_for_scale
)
w2_scale
=
(
torch
.
rand
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
)
*
factor_for_scale
)
else
:
else
:
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
...
@@ -141,10 +154,12 @@ def benchmark_config(config: BenchmarkConfig,
...
@@ -141,10 +154,12 @@ def benchmark_config(config: BenchmarkConfig,
def
run
():
def
run
():
from
vllm.model_executor.layers.fused_moe
import
override_config
from
vllm.model_executor.layers.fused_moe
import
override_config
with
override_config
(
config
):
with
override_config
(
config
):
if
use_deep_gemm
:
if
use_deep_gemm
:
topk_weights
,
topk_ids
=
fused_topk
(
x
,
input_gating
,
topk
,
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
False
)
x
,
input_gating
,
topk
,
False
)
return
fused_experts
(
return
fused_experts
(
x
,
x
,
w1
,
w1
,
...
@@ -249,8 +264,7 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
...
@@ -249,8 +264,7 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
return
param_ranges
return
param_ranges
def
get_configs_compute_bound
(
use_fp16
,
def
get_configs_compute_bound
(
use_fp16
,
block_quant_shape
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
list
[
dict
[
str
,
int
]]:
block_quant_shape
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
list
[
dict
[
str
,
int
]]:
configs
:
list
[
BenchmarkConfig
]
=
[]
configs
:
list
[
BenchmarkConfig
]
=
[]
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
...
@@ -286,20 +300,25 @@ def get_configs_compute_bound(use_fp16,
...
@@ -286,20 +300,25 @@ def get_configs_compute_bound(use_fp16,
if
block_quant_shape
is
not
None
and
not
use_fp16
:
if
block_quant_shape
is
not
None
and
not
use_fp16
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
for
config
in
configs
[:]:
for
config
in
configs
[:]:
if
config
[
"BLOCK_SIZE_K"
]
%
block_k
!=
0
or
config
[
if
(
"BLOCK_SIZE_N"
]
%
block_n
!=
0
:
config
[
"BLOCK_SIZE_K"
]
%
block_k
!=
0
or
config
[
"BLOCK_SIZE_N"
]
%
block_n
!=
0
):
configs
.
remove
(
config
)
configs
.
remove
(
config
)
return
configs
return
configs
def
prune_rocm_search_space
(
num_tokens
,
shard_intermediate_size
,
hidden_size
,
def
prune_rocm_search_space
(
search_space
,
is_fp16
,
topk
):
num_tokens
,
shard_intermediate_size
,
hidden_size
,
search_space
,
is_fp16
,
topk
):
N1
,
K1
=
shard_intermediate_size
,
hidden_size
N1
,
K1
=
shard_intermediate_size
,
hidden_size
N2
,
K2
=
hidden_size
,
shard_intermediate_size
//
2
N2
,
K2
=
hidden_size
,
shard_intermediate_size
//
2
pruned_space_1
=
prune_rocm_configs
(
num_tokens
*
topk
,
N1
,
K1
,
pruned_space_1
=
prune_rocm_configs
(
search_space
,
is_fp16
)
num_tokens
*
topk
,
N1
,
K1
,
search_space
,
is_fp16
pruned_space_2
=
prune_rocm_configs
(
num_tokens
*
topk
,
N2
,
K2
,
)
search_space
,
is_fp16
)
pruned_space_2
=
prune_rocm_configs
(
num_tokens
*
topk
,
N2
,
K2
,
search_space
,
is_fp16
)
search_space
=
merge_unique_dicts
(
pruned_space_1
,
pruned_space_2
)
search_space
=
merge_unique_dicts
(
pruned_space_1
,
pruned_space_2
)
return
search_space
return
search_space
...
@@ -340,15 +359,16 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
...
@@ -340,15 +359,16 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
# DCU currently does not support matrix_instr_nonkdim param
# DCU currently does not support matrix_instr_nonkdim param
# if is_fp16:
# if is_fp16:
# if (matrix_instr_nonkdim > BLOCK_SIZE_M
# if (
# or matrix_instr_nonkdim > BLOCK_SIZE_N):
# matrix_instr_nonkdim > BLOCK_SIZE_M
# or matrix_instr_nonkdim > BLOCK_SIZE_N
# ):
# continue
# continue
# if (matrix_instr_nonkdim >= M
# if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
# and matrix_instr_nonkdim != BLOCK_SIZE_M):
# continue
# continue
# if (matrix_instr_nonkdim >= N
# if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
# and matrix_instr_nonkdim != BLOCK_SIZE_N):
# continue
# continue
# Skip BLOCK_SIZE that is too large compare to M/N
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
# unless BLOCK_SIZE is already small enough
if
M
*
2
<
BLOCK_SIZE_M
and
BLOCK_SIZE_M
!=
16
:
if
M
*
2
<
BLOCK_SIZE_M
and
BLOCK_SIZE_M
!=
16
:
...
@@ -368,8 +388,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
...
@@ -368,8 +388,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
continue
continue
# out of shared memory resource
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS
=
(
BLOCK_SIZE_K
*
BLOCK_SIZE_M
*
elemBytes_a
+
LDS
=
(
BLOCK_SIZE_K
*
BLOCK_SIZE_N
*
elemBytes_b
)
BLOCK_SIZE_K
*
BLOCK_SIZE_M
*
elemBytes_a
+
BLOCK_SIZE_K
*
BLOCK_SIZE_N
*
elemBytes_b
)
if
LDS
>
65536
:
if
LDS
>
65536
:
continue
continue
# Skip small block sizes and num_warps for large gemm
# Skip small block sizes and num_warps for large gemm
...
@@ -403,7 +425,6 @@ def merge_unique_dicts(list1, list2):
...
@@ -403,7 +425,6 @@ def merge_unique_dicts(list1, list2):
@
ray
.
remote
(
num_gpus
=
1
)
@
ray
.
remote
(
num_gpus
=
1
)
class
BenchmarkWorker
:
class
BenchmarkWorker
:
def
__init__
(
self
,
seed
:
int
,
device_id
:
int
)
->
None
:
def
__init__
(
self
,
seed
:
int
,
device_id
:
int
)
->
None
:
torch
.
set_default_device
(
"cuda:"
+
str
(
device_id
))
torch
.
set_default_device
(
"cuda:"
+
str
(
device_id
))
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
...
@@ -423,43 +444,47 @@ class BenchmarkWorker:
...
@@ -423,43 +444,47 @@ class BenchmarkWorker:
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
L
ist
[
int
]
=
None
,
block_quant_shape
:
l
ist
[
int
]
=
None
,
use_deep_gemm
:
bool
=
False
,
use_deep_gemm
:
bool
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
)
->
tuple
[
dict
[
str
,
int
],
float
]:
current_platform
.
seed_everything
(
self
.
seed
)
current_platform
.
seed_everything
(
self
.
seed
)
dtype_str
=
get_config_dtype_str
(
dtype
,
dtype_str
=
get_config_dtype_str
(
use_int8_w8a16
=
use_int8_w8a16
,
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
use_fp8_w8a8
=
use_fp8_w8a8
)
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
# is the intermediate size after silu_and_mul.
op_config
=
get_moe_configs
(
num_experts
,
shard_intermediate_size
//
2
,
op_config
=
get_moe_configs
(
dtype_str
,
use_nn_moe
=
nn_moe
)
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
,
use_nn_moe
=
nn_moe
)
if
op_config
is
None
:
if
op_config
is
None
:
config
=
get_default_config
(
num_tokens
,
config
=
get_default_config
(
num_experts
,
num_tokens
,
shard_intermediate_size
,
num_experts
,
hidden_size
,
shard_intermediate_size
,
topk
,
hidden_size
,
dtype_str
,
topk
,
is_marlin
=
False
,
dtype_str
,
use_nn_moe
=
nn_moe
)
is_marlin
=
False
,
use_nn_moe
=
nn_moe
)
else
:
else
:
config
=
op_config
[
min
(
op_config
.
keys
(),
config
=
op_config
[
min
(
op_config
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
num_tokens
))]
key
=
lambda
x
:
abs
(
x
-
num_tokens
))]
kernel_time
=
benchmark_config
(
kernel_time
=
benchmark_config
(
config
,
config
,
num_tokens
,
num_tokens
,
num_experts
,
num_experts
,
shard_intermediate_size
,
shard_intermediate_size
,
hidden_size
,
hidden_size
,
topk
,
topk
,
dtype
,
dtype
,
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
num_iters
=
100
,
num_iters
=
100
,
block_quant_shape
=
block_quant_shape
,
block_quant_shape
=
block_quant_shape
,
use_deep_gemm
=
use_deep_gemm
,
use_deep_gemm
=
use_deep_gemm
,
nn_moe
=
nn_moe
)
use_nn_moe
=
nn_moe
)
return
config
,
kernel_time
return
config
,
kernel_time
def
tune
(
def
tune
(
...
@@ -481,13 +506,22 @@ class BenchmarkWorker:
...
@@ -481,13 +506,22 @@ class BenchmarkWorker:
best_time
=
float
(
"inf"
)
best_time
=
float
(
"inf"
)
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
is_fp16
=
not
(
use_fp8_w8a8
or
use_int8_w8a16
)
is_fp16
=
not
(
use_fp8_w8a8
or
use_int8_w8a16
)
search_space
=
prune_rocm_search_space
(
num_tokens
,
search_space
=
prune_rocm_search_space
(
shard_intermediate_size
,
num_tokens
,
hidden_size
,
search_space
,
shard_intermediate_size
,
is_fp16
,
topk
)
hidden_size
,
search_space
,
is_fp16
,
topk
,
)
need_device_guard
=
False
if
current_platform
.
is_rocm
():
visible_device
=
os
.
environ
.
get
(
"ROCR_VISIBLE_DEVICES"
,
None
)
if
visible_device
!=
f
"
{
self
.
device_id
}
"
:
need_device_guard
=
True
with
torch
.
cuda
.
device
(
self
.
device_id
)
if
current_platform
.
is_rocm
(
with
torch
.
cuda
.
device
(
self
.
device_id
)
if
need_device_guard
else
nullcontext
():
)
else
nullcontext
():
for
config
in
tqdm
(
search_space
):
for
config
in
tqdm
(
search_space
):
try
:
try
:
kernel_time
=
benchmark_config
(
kernel_time
=
benchmark_config
(
...
@@ -520,45 +554,48 @@ class BenchmarkWorker:
...
@@ -520,45 +554,48 @@ class BenchmarkWorker:
def
sort_config
(
config
:
BenchmarkConfig
)
->
BenchmarkConfig
:
def
sort_config
(
config
:
BenchmarkConfig
)
->
BenchmarkConfig
:
return
{
return
{
"BLOCK_SIZE_M"
:
"BLOCK_SIZE_M"
:
config
[
"BLOCK_SIZE_M"
],
config
[
"BLOCK_SIZE_M"
],
"BLOCK_SIZE_N"
:
config
[
"BLOCK_SIZE_N"
],
"BLOCK_SIZE_N"
:
"BLOCK_SIZE_K"
:
config
[
"BLOCK_SIZE_K"
],
config
[
"BLOCK_SIZE_N"
],
"GROUP_SIZE_M"
:
config
[
"GROUP_SIZE_M"
],
"BLOCK_SIZE_K"
:
"num_warps"
:
config
[
"num_warps"
],
config
[
"BLOCK_SIZE_K"
],
"num_stages"
:
config
[
"num_stages"
],
"GROUP_SIZE_M"
:
**
(
config
[
"GROUP_SIZE_M"
],
{
"num_ldmatrixes"
:
config
[
"num_ldmatrixes"
]}
if
"num_ldmatrixes"
in
config
else
{}
"num_warps"
:
),
config
[
"num_warps"
],
**
(
"num_stages"
:
{
"waves_per_eu"
:
config
[
"waves_per_eu"
]}
if
"waves_per_eu"
in
config
else
{}
config
[
"num_stages"
],
),
**
({
**
(
"num_ldmatrixes"
:
config
[
"num_ldmatrixes"
]
{
"matrix_instr_nonkdim"
:
config
[
"matrix_instr_nonkdim"
]}
}
if
"num_ldmatrixes"
in
config
else
{}),
if
"matrix_instr_nonkdim"
in
config
**
({
else
{}
"waves_per_eu"
:
config
[
"waves_per_eu"
]
),
}
if
"waves_per_eu"
in
config
else
{}),
**
({
"kpack"
:
config
[
"kpack"
]}
if
"kpack"
in
config
else
{}),
**
({
}
"matrix_instr_nonkdim"
:
config
[
"matrix_instr_nonkdim"
]
}
if
"matrix_instr_nonkdim"
in
config
else
{}),
**
({
"kpack"
:
config
[
"kpack"
]
}
if
"kpack"
in
config
else
{}),
}
def
save_configs
(
configs
:
dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
def
save_configs
(
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
configs
:
dict
[
int
,
BenchmarkConfig
],
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_experts
:
int
,
block_quant_shape
:
List
[
int
],
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
shard_intermediate_size
:
int
,
dtype_str
=
get_config_dtype_str
(
dtype
,
hidden_size
:
int
,
use_int8_w8a16
=
use_int8_w8a16
,
topk
:
int
,
use_fp8_w8a8
=
use_fp8_w8a8
)
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
list
[
int
],
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
None
:
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
# is the intermediate size after silu_and_mul.
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
filename
=
get_config_file_name
(
dtype_str
,
block_quant_shape
,
use_nn_moe
=
use_nn_moe
)
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
,
block_quant_shape
,
use_nn_moe
=
use_nn_moe
)
print
(
f
"Writing best config to
{
filename
}
..."
)
print
(
f
"Writing best config to
{
filename
}
..."
)
with
open
(
filename
,
"w"
)
as
f
:
with
open
(
filename
,
"w"
)
as
f
:
...
@@ -567,21 +604,20 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
...
@@ -567,21 +604,20 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
def
get_weight_block_size_safety
(
config
,
default_value
=
None
):
def
get_weight_block_size_safety
(
config
,
default_value
=
None
):
quantization_config
=
getattr
(
config
,
"quantization_config"
,
{})
quantization_config
=
getattr
(
config
,
'quantization_config'
,
{})
if
isinstance
(
quantization_config
,
dict
):
if
isinstance
(
quantization_config
,
dict
):
return
quantization_config
.
get
(
'
weight_block_size
'
,
default_value
)
return
quantization_config
.
get
(
"
weight_block_size
"
,
default_value
)
return
default_value
return
default_value
def
main
(
args
:
argparse
.
Namespace
):
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
print
(
args
)
tp_size
=
args
.
tp_size
tp_size
=
args
.
tp_size
config
=
get_config
(
model
=
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
args
.
model_prefix
:
config
=
getattr
(
config
,
args
.
model_prefix
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
E
=
config
.
ffn_config
.
moe_num_experts
E
=
config
.
ffn_config
.
moe_num_experts
topk
=
config
.
ffn_config
.
moe_top_k
topk
=
config
.
ffn_config
.
moe_top_k
...
@@ -592,15 +628,12 @@ def main(args: argparse.Namespace):
...
@@ -592,15 +628,12 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
(
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
elif
config
.
architectures
[
0
]
in
(
"DeepseekV3ForCausalLM"
,
"DeepseekV2ForCausalLM"
):
or
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
):
E
=
config
.
n_routed_experts
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
in
[
elif
config
.
architectures
[
0
]
in
(
"Qwen2MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
):
"Qwen2MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
]:
E
=
config
.
num_experts
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
...
@@ -622,18 +655,41 @@ def main(args: argparse.Namespace):
...
@@ -622,18 +655,41 @@ def main(args: argparse.Namespace):
if
args
.
batch_size
is
None
:
if
args
.
batch_size
is
None
:
batch_sizes
=
[
batch_sizes
=
[
1
,
2
,
4
,
8
,
16
,
24
,
32
,
48
,
64
,
96
,
128
,
256
,
512
,
1024
,
1536
,
1
,
2048
,
3072
,
4096
2
,
4
,
8
,
16
,
24
,
32
,
48
,
64
,
96
,
128
,
256
,
512
,
1024
,
1536
,
2048
,
3072
,
4096
,
]
]
else
:
else
:
batch_sizes
=
[
args
.
batch_size
]
batch_sizes
=
[
args
.
batch_size
]
use_deep_gemm
=
bool
(
args
.
use_deep_gemm
)
use_deep_gemm
=
bool
(
args
.
use_deep_gemm
)
ray
.
init
(
address
=
None
,
if
current_platform
.
is_rocm
()
and
"HIP_VISIBLE_DEVICES"
in
os
.
environ
:
ignore_reinit_error
=
True
,
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
num_gpus
=
args
.
num_gpus
)
logger
.
warning
(
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
)
val
=
os
.
environ
[
"HIP_VISIBLE_DEVICES"
]
os
.
environ
[
"ROCR_VISIBLE_DEVICES"
]
=
val
del
os
.
environ
[
"HIP_VISIBLE_DEVICES"
]
ray
.
init
(
address
=
None
,
ignore_reinit_error
=
True
,
num_gpus
=
args
.
num_gpus
)
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
,
i
)
for
i
in
range
(
num_gpus
)]
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
,
i
)
for
i
in
range
(
num_gpus
)]
...
@@ -655,25 +711,62 @@ def main(args: argparse.Namespace):
...
@@ -655,25 +711,62 @@ def main(args: argparse.Namespace):
start
=
time
.
time
()
start
=
time
.
time
()
configs
=
_distribute
(
configs
=
_distribute
(
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
"tune"
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_space
,
[
block_quant_shape
,
use_deep_gemm
,
args
.
nn_moe
)
(
for
batch_size
in
batch_sizes
])
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_space
,
block_quant_shape
,
use_deep_gemm
,
args
.
nn_moe
,
)
for
batch_size
in
batch_sizes
],
)
best_configs
=
{
best_configs
=
{
M
:
sort_config
(
config
)
M
:
sort_config
(
config
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
}
}
save_configs
(
best_configs
,
E
,
shard_intermediate_size
,
hidden_size
,
save_configs
(
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
best_configs
,
block_quant_shape
,
use_nn_moe
=
args
.
nn_moe
)
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
,
use_nn_moe
=
args
.
nn_moe
,
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
f
"Tuning took
{
end
-
start
:.
2
f
}
seconds"
)
print
(
f
"Tuning took
{
end
-
start
:.
2
f
}
seconds"
)
else
:
else
:
outputs
=
_distribute
(
outputs
=
_distribute
(
"benchmark"
,
"benchmark"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
[
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
,
use_deep_gemm
,
args
.
nn_moe
)
(
for
batch_size
in
batch_sizes
])
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
,
use_deep_gemm
,
args
.
nn_moe
,
)
for
batch_size
in
batch_sizes
],
)
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
print
(
f
"Batch size:
{
batch_size
}
, config:
{
config
}
"
)
print
(
f
"Batch size:
{
batch_size
}
, config:
{
config
}
"
)
...
@@ -682,24 +775,22 @@ def main(args: argparse.Namespace):
...
@@ -682,24 +775,22 @@ def main(args: argparse.Namespace):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--model"
,
parser
.
add_argument
(
type
=
str
,
"--model"
,
type
=
str
,
default
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
default
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
)
)
parser
.
add_argument
(
"--tp-size"
,
parser
.
add_argument
(
"-tp"
,
"--tp-size"
,
"-tp"
,
"--tensor-parallel-size"
,
type
=
int
,
default
=
2
"--tensor-parallel-size"
,
)
type
=
int
,
parser
.
add_argument
(
default
=
2
)
"--dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
],
default
=
"auto"
parser
.
add_argument
(
"--dtype"
,
)
type
=
str
,
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
],
default
=
"auto"
)
parser
.
add_argument
(
"--use-deep-gemm"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--use-deep-gemm"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--nn-moe"
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
"--nn-moe"
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--model-prefix"
,
type
=
str
,
required
=
False
)
parser
.
add_argument
(
"--num-gpus"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--num-gpus"
,
type
=
int
,
default
=
1
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
benchmarks/kernels/benchmark_moe_permute_unpermute.py
0 → 100644
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
typing
import
Any
,
TypedDict
import
ray
import
torch
from
transformers
import
AutoConfig
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_moe_permute
,
_moe_unpermute_and_reduce
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
*
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
class
BenchmarkConfig
(
TypedDict
):
BLOCK_SIZE_M
:
int
BLOCK_SIZE_N
:
int
BLOCK_SIZE_K
:
int
GROUP_SIZE_M
:
int
num_warps
:
int
num_stages
:
int
def
benchmark_permute
(
num_tokens
:
int
,
num_experts
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
use_customized_permute
:
bool
=
False
,
)
->
float
:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
# output_hidden_states = torch.empty_like(hidden_states)
if
use_fp8_w8a8
:
align_block_size
=
128
# deepgemm needs 128 m aligned block
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
else
:
align_block_size
=
None
qhidden_states
=
hidden_states
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
input_gating
=
torch
.
randn
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
qhidden_states
,
input_gating
,
topk
,
False
)
def
prepare
(
i
:
int
):
input_gating
.
copy_
(
gating_output
[
i
])
def
run
():
if
use_customized_permute
:
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
)
=
(
moe_permute
(
qhidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
token_expert_indices
=
token_expert_indices
,
topk
=
topk
,
n_expert
=
num_experts
,
n_local_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
)
else
:
(
permuted_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
_moe_permute
(
qhidden_states
,
None
,
topk_ids
,
num_experts
,
None
,
align_block_size
)
# JIT compilation & warmup
run
()
torch
.
cuda
.
synchronize
()
# Capture 10 invocations with CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
for
_
in
range
(
10
):
run
()
torch
.
cuda
.
synchronize
()
# Warmup
for
_
in
range
(
5
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
prepare
(
i
)
torch
.
cuda
.
synchronize
()
start_event
.
record
()
graph
.
replay
()
end_event
.
record
()
end_event
.
synchronize
()
latencies
.
append
(
start_event
.
elapsed_time
(
end_event
))
avg
=
sum
(
latencies
)
/
(
num_iters
*
10
)
*
1000
# us
graph
.
reset
()
return
avg
def
benchmark_unpermute
(
num_tokens
:
int
,
num_experts
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
use_customized_permute
:
bool
=
False
,
)
->
float
:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
output_hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
use_fp8_w8a8
:
align_block_size
=
128
# deepgemm needs 128 m aligned block
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
else
:
align_block_size
=
None
qhidden_states
=
hidden_states
input_gating
=
torch
.
randn
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
qhidden_states
,
input_gating
,
topk
,
False
)
def
prepare
():
if
use_customized_permute
:
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
)
=
(
moe_permute
(
qhidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
token_expert_indices
=
token_expert_indices
,
topk
=
topk
,
n_expert
=
num_experts
,
n_local_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
)
# convert to fp16/bf16 as gemm output
return
(
permuted_hidden_states
.
to
(
dtype
),
first_token_off
,
inv_perm_idx
,
m_indices
,
)
else
:
(
permuted_qhidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
_moe_permute
(
qhidden_states
,
None
,
topk_ids
,
num_experts
,
None
,
align_block_size
)
# convert to fp16/bf16 as gemm output
return
(
permuted_qhidden_states
.
to
(
dtype
),
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
def
run
(
input
:
tuple
):
if
use_customized_permute
:
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
)
=
input
moe_unpermute
(
permuted_hidden_states
,
topk_weights
,
topk_ids
,
inv_perm_idx
,
first_token_off
,
topk
,
num_experts
,
num_experts
,
)
else
:
(
permuted_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
input
_moe_unpermute_and_reduce
(
output_hidden_states
,
permuted_hidden_states
,
inv_perm
,
topk_weights
)
# JIT compilation & warmup
input
=
prepare
()
run
(
input
)
torch
.
cuda
.
synchronize
()
# Capture 10 invocations with CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
for
_
in
range
(
10
):
run
(
input
)
torch
.
cuda
.
synchronize
()
# Warmup
for
_
in
range
(
5
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
torch
.
cuda
.
synchronize
()
start_event
.
record
()
graph
.
replay
()
end_event
.
record
()
end_event
.
synchronize
()
latencies
.
append
(
start_event
.
elapsed_time
(
end_event
))
avg
=
sum
(
latencies
)
/
(
num_iters
*
10
)
*
1000
# us
graph
.
reset
()
return
avg
@
ray
.
remote
(
num_gpus
=
1
)
class
BenchmarkWorker
:
def
__init__
(
self
,
seed
:
int
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
seed
)
self
.
seed
=
seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self
.
device_id
=
int
(
ray
.
get_gpu_ids
()[
0
])
def
benchmark
(
self
,
num_tokens
:
int
,
num_experts
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_customized_permute
:
bool
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
current_platform
.
seed_everything
(
self
.
seed
)
permute_time
=
benchmark_permute
(
num_tokens
,
num_experts
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
use_customized_permute
=
use_customized_permute
,
)
unpermute_time
=
benchmark_unpermute
(
num_tokens
,
num_experts
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
use_customized_permute
=
use_customized_permute
,
)
return
permute_time
,
unpermute_time
def
get_weight_block_size_safety
(
config
,
default_value
=
None
):
quantization_config
=
getattr
(
config
,
"quantization_config"
,
{})
if
isinstance
(
quantization_config
,
dict
):
return
quantization_config
.
get
(
"weight_block_size"
,
default_value
)
return
default_value
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
E
=
config
.
ffn_config
.
moe_num_experts
topk
=
config
.
ffn_config
.
moe_top_k
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
elif
(
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
or
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
):
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
elif
config
.
architectures
[
0
]
in
[
"Qwen2MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
]:
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
else
:
# Support for llama4
config
=
config
.
get_text_config
()
# Default: Mixtral.
E
=
config
.
num_local_experts
topk
=
config
.
num_experts_per_tok
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_customized_permute
=
args
.
use_customized_permute
if
args
.
batch_size
is
None
:
batch_sizes
=
[
1
,
2
,
4
,
8
,
16
,
24
,
32
,
48
,
64
,
96
,
128
,
256
,
512
,
1024
,
1536
,
2048
,
3072
,
4096
,
]
else
:
batch_sizes
=
[
args
.
batch_size
]
ray
.
init
()
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
)
for
_
in
range
(
num_gpus
)]
def
_distribute
(
method
:
str
,
inputs
:
list
[
Any
])
->
list
[
Any
]:
outputs
=
[]
worker_idx
=
0
for
input_args
in
inputs
:
worker
=
workers
[
worker_idx
]
worker_method
=
getattr
(
worker
,
method
)
output
=
worker_method
.
remote
(
*
input_args
)
outputs
.
append
(
output
)
worker_idx
=
(
worker_idx
+
1
)
%
num_gpus
return
ray
.
get
(
outputs
)
outputs
=
_distribute
(
"benchmark"
,
[
(
batch_size
,
E
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_customized_permute
,
)
for
batch_size
in
batch_sizes
],
)
for
batch_size
,
(
permute
,
unpermute
)
in
zip
(
batch_sizes
,
outputs
):
print
(
f
"Batch size:
{
batch_size
}
"
)
print
(
f
"Permute time:
{
permute
:.
2
f
}
us"
)
print
(
f
"Unpermute time:
{
unpermute
:.
2
f
}
us"
)
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
],
default
=
"auto"
)
parser
.
add_argument
(
"--use-customized-permute"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
main
(
args
)
benchmarks/kernels/benchmark_paged_attention.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
random
import
time
import
time
...
@@ -9,10 +10,13 @@ import torch
...
@@ -9,10 +10,13 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
)
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
,
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -40,19 +44,15 @@ def main(
...
@@ -40,19 +44,15 @@ def main(
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
query
=
torch
.
empty
(
num_seqs
,
query
=
torch
.
empty
(
num_query_heads
,
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
,
device
=
device
head_size
,
)
dtype
=
dtype
,
device
=
device
)
query
.
uniform_
(
-
scale
,
scale
)
query
.
uniform_
(
-
scale
,
scale
)
assert
num_query_heads
%
num_kv_heads
==
0
assert
num_query_heads
%
num_kv_heads
==
0
alibi_slopes
=
None
alibi_slopes
=
None
if
use_alibi
:
if
use_alibi
:
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
,
device
=
device
)
dtype
=
torch
.
float
,
device
=
device
)
seq_lens
=
[
seq_len
for
_
in
range
(
num_seqs
)]
seq_lens
=
[
seq_len
for
_
in
range
(
num_seqs
)]
max_seq_len
=
max
(
seq_lens
)
max_seq_len
=
max
(
seq_lens
)
...
@@ -63,24 +63,23 @@ def main(
...
@@ -63,24 +63,23 @@ def main(
block_tables_lst
:
list
[
list
[
int
]]
=
[]
block_tables_lst
:
list
[
list
[
int
]]
=
[]
for
_
in
range
(
num_seqs
):
for
_
in
range
(
num_seqs
):
block_table
=
[
block_table
=
[
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
for
_
in
range
(
max_num_blocks_per_seq
)
for
_
in
range
(
max_num_blocks_per_seq
)
]
]
block_tables_lst
.
append
(
block_table
)
block_tables_lst
.
append
(
block_table
)
block_tables
=
torch
.
tensor
(
block_tables_lst
,
block_tables
=
torch
.
tensor
(
block_tables_lst
,
dtype
=
torch
.
int
,
device
=
device
)
dtype
=
torch
.
int
,
device
=
device
)
# Create the KV cache.
# Create the KV cache.
key_caches
,
value_caches
=
create_kv_caches_with_random
(
NUM_BLOCKS
,
key_caches
,
value_caches
=
create_kv_caches_with_random
(
block_size
,
NUM_BLOCKS
,
1
,
block_size
,
num_kv_heads
,
1
,
head_size
,
num_kv_heads
,
kv_cache_dtype
,
head_size
,
dtype
,
kv_cache_dtype
,
device
=
device
)
dtype
,
device
=
device
,
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Prepare for the paged attention kernel.
# Prepare for the paged attention kernel.
...
@@ -88,11 +87,11 @@ def main(
...
@@ -88,11 +87,11 @@ def main(
if
version
==
"v2"
:
if
version
==
"v2"
:
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
global
PARTITION_SIZE
global
PARTITION_SIZE
if
not
args
.
custom_paged_attn
:
if
not
args
.
custom_paged_attn
and
not
current_platform
.
is_navi
()
:
PARTITION_SIZE
=
1024
PARTITION_SIZE
=
1024
else
:
else
:
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
num_partitions
=
(
(
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
(
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
tmp_output
=
torch
.
empty
(
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
dtype
=
output
.
dtype
,
...
@@ -112,9 +111,7 @@ def main(
...
@@ -112,9 +111,7 @@ def main(
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
# Using default kv_scale
# Using default kv_scale
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
dtype
=
torch
.
float32
,
device
=
device
)
for
_
in
range
(
num_iters
):
for
_
in
range
(
num_iters
):
if
version
==
"v1"
:
if
version
==
"v1"
:
...
@@ -242,6 +239,7 @@ def main(
...
@@ -242,6 +239,7 @@ def main(
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
None
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
...
@@ -271,30 +269,29 @@ def main(
...
@@ -271,30 +269,29 @@ def main(
print
(
f
"Kernel running time:
{
latency
*
1000000
:.
3
f
}
us"
)
print
(
f
"Kernel running time:
{
latency
*
1000000
:.
3
f
}
us"
)
if
__name__
==
'__main__'
:
if
__name__
==
"__main__"
:
logger
.
warning
(
"This script benchmarks the paged attention kernel. "
logger
.
warning
(
"By default this is no longer used in vLLM inference."
)
"This script benchmarks the paged attention kernel. "
"By default this is no longer used in vLLM inference."
)
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the paged attention kernel."
)
description
=
"Benchmark the paged attention kernel."
)
parser
.
add_argument
(
"--version"
,
type
=
str
,
choices
=
[
"v1"
,
"v2"
],
default
=
"v2"
)
parser
.
add_argument
(
"--version"
,
type
=
str
,
choices
=
[
"v1"
,
"v2"
],
default
=
"v2"
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--seq-len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--seq-len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--num-query-heads"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--num-query-heads"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
parser
.
add_argument
(
type
=
int
,
"--head-size"
,
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
type
=
int
,
default
=
128
)
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
default
=
128
,
)
parser
.
add_argument
(
"--block-size"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
16
)
parser
.
add_argument
(
"--block-size"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
16
)
parser
.
add_argument
(
"--use-alibi"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--use-alibi"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--dtype"
,
parser
.
add_argument
(
type
=
str
,
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
)
default
=
"half"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -305,15 +302,15 @@ if __name__ == '__main__':
...
@@ -305,15 +302,15 @@ if __name__ == '__main__':
help
=
"Data type for kv cache storage. If 'auto', will use model "
help
=
"Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (hcu) supports fp8 (=fp8_e4m3)"
)
"ROCm (hcu) supports fp8 (=fp8_e4m3)"
)
parser
.
add_argument
(
"--gc-paged-attn"
,
parser
.
add_argument
(
action
=
"store_true"
,
"--gc-paged-attn"
,
action
=
"store_true"
,
help
=
"Use gc paged attention"
help
=
"Use gc paged attention"
)
)
parser
.
add_argument
(
"--tc-paged-attn"
,
parser
.
add_argument
(
action
=
"store_true"
,
"--tc-paged-attn"
,
action
=
"store_true"
,
help
=
"Use tc paged attention"
help
=
"Use tc paged attention"
)
)
parser
.
add_argument
(
"--custom-paged-attn"
,
parser
.
add_argument
(
action
=
"store_true"
,
"--custom-paged-attn"
,
action
=
"store_true"
,
help
=
"Use custom paged attention"
help
=
"Use custom paged attention"
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
print
(
args
)
...
...
benchmarks/kernels/benchmark_quant.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
time
...
@@ -10,15 +11,17 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
...
@@ -10,15 +11,17 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
main
(
num_tokens
:
int
,
def
main
(
hidden_size
:
int
,
num_tokens
:
int
,
static_scale
:
bool
,
hidden_size
:
int
,
quant_dtype
:
torch
.
dtype
,
static_scale
:
bool
,
dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
seed
:
int
=
0
,
dtype
:
torch
.
dtype
,
do_profile
:
bool
=
False
,
seed
:
int
=
0
,
num_warmup_iters
:
int
=
5
,
do_profile
:
bool
=
False
,
num_iters
:
int
=
100
)
->
None
:
num_warmup_iters
:
int
=
5
,
num_iters
:
int
=
100
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
...
@@ -56,7 +59,7 @@ def main(num_tokens: int,
...
@@ -56,7 +59,7 @@ def main(num_tokens: int,
print
(
f
"Kernel running time:
{
latency
*
1000000
:.
3
f
}
us"
)
print
(
f
"Kernel running time:
{
latency
*
1000000
:.
3
f
}
us"
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
def
to_torch_dtype
(
dt
):
def
to_torch_dtype
(
dt
):
if
dt
==
"int8"
:
if
dt
==
"int8"
:
...
@@ -66,37 +69,40 @@ if __name__ == '__main__':
...
@@ -66,37 +69,40 @@ if __name__ == '__main__':
raise
ValueError
(
f
"Unsupported dtype:
{
dt
}
"
)
raise
ValueError
(
f
"Unsupported dtype:
{
dt
}
"
)
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the quantization (fp8 or int8) kernel."
)
description
=
"Benchmark the quantization (fp8 or int8) kernel."
)
parser
.
add_argument
(
"--num-tokens"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--num-tokens"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--hidden-size"
,
type
=
int
,
default
=
8192
)
parser
.
add_argument
(
"--hidden-size"
,
type
=
int
,
default
=
8192
)
parser
.
add_argument
(
"--static-scale"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--static-scale"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--quant-dtype"
,
parser
.
add_argument
(
type
=
str
,
"--quant-dtype"
,
type
=
str
,
choices
=
[
"fp8"
,
"int8"
],
default
=
"int8"
choices
=
[
"fp8"
,
"int8"
],
)
default
=
"int8"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--dtype"
,
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
type
=
str
,
)
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--num-warmup-iters"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--num-warmup-iters"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--num-iters"
,
parser
.
add_argument
(
type
=
int
,
"--num-iters"
,
default
=
100
,
type
=
int
,
help
=
"Number of benchmark iterations. "
default
=
100
,
"If --profile is set, this number is ignored"
)
help
=
"Number of benchmark iterations. "
"If --profile is set, this number is ignored"
,
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
print
(
args
)
main
(
num_tokens
=
args
.
num_tokens
,
main
(
hidden_size
=
args
.
hidden_size
,
num_tokens
=
args
.
num_tokens
,
static_scale
=
args
.
static_scale
,
hidden_size
=
args
.
hidden_size
,
quant_dtype
=
to_torch_dtype
(
args
.
quant_dtype
),
static_scale
=
args
.
static_scale
,
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
args
.
dtype
],
quant_dtype
=
to_torch_dtype
(
args
.
quant_dtype
),
seed
=
args
.
seed
,
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
args
.
dtype
],
do_profile
=
args
.
profile
,
seed
=
args
.
seed
,
num_warmup_iters
=
args
.
num_warmup_iters
,
do_profile
=
args
.
profile
,
num_iters
=
args
.
num_iters
)
num_warmup_iters
=
args
.
num_warmup_iters
,
num_iters
=
args
.
num_iters
,
)
benchmarks/kernels/benchmark_rmsnorm.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
itertools
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
import
torch
import
torch
import
triton
from
flashinfer.norm
import
fused_add_rmsnorm
,
rmsnorm
from
flashinfer.norm
import
fused_add_rmsnorm
,
rmsnorm
from
torch
import
nn
from
torch
import
nn
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm.triton_utils
import
triton
class
HuggingFaceRMSNorm
(
nn
.
Module
):
class
HuggingFaceRMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
)
->
None
:
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
...
@@ -114,23 +114,19 @@ def rmsnorm_vllm(
...
@@ -114,23 +114,19 @@ def rmsnorm_vllm(
def
calculate_diff
(
batch_size
,
seq_len
,
hidden_size
,
use_residual
=
True
):
def
calculate_diff
(
batch_size
,
seq_len
,
hidden_size
,
use_residual
=
True
):
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
x
=
torch
.
randn
(
batch_size
,
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
seq_len
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
residual
=
torch
.
randn_like
(
x
)
if
use_residual
else
None
residual
=
torch
.
randn_like
(
x
)
if
use_residual
else
None
output_naive
=
rmsnorm_naive
(
output_naive
=
rmsnorm_naive
(
x
.
clone
(),
weight
,
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
residual
.
clone
()
if
residual
is
not
None
else
None
)
)
output_flashinfer
=
rmsnorm_flashinfer
(
output_flashinfer
=
rmsnorm_flashinfer
(
x
.
clone
(),
weight
,
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
residual
.
clone
()
if
residual
is
not
None
else
None
)
)
output_vllm
=
rmsnorm_vllm
(
output_vllm
=
rmsnorm_vllm
(
x
.
clone
(),
weight
,
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
residual
.
clone
()
if
residual
is
not
None
else
None
)
)
if
use_residual
:
if
use_residual
:
output_naive
=
output_naive
[
0
]
output_naive
=
output_naive
[
0
]
...
@@ -141,9 +137,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
...
@@ -141,9 +137,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
print
(
f
"FlashInfer output=
{
output_flashinfer
}
"
)
print
(
f
"FlashInfer output=
{
output_flashinfer
}
"
)
print
(
f
"vLLM output=
{
output_vllm
}
"
)
print
(
f
"vLLM output=
{
output_vllm
}
"
)
if
torch
.
allclose
(
output_naive
,
output_flashinfer
,
atol
=
1e-2
,
if
torch
.
allclose
(
rtol
=
1e-2
)
and
torch
.
allclose
(
output_naive
,
output_flashinfer
,
atol
=
1e-2
,
rtol
=
1e-2
output_naive
,
output_vllm
,
atol
=
1e-2
,
rtol
=
1e-2
):
)
and
torch
.
allclose
(
output_naive
,
output_vllm
,
atol
=
1e-2
,
rtol
=
1e-2
):
print
(
"✅ All implementations match"
)
print
(
"✅ All implementations match"
)
else
:
else
:
print
(
"❌ Implementations differ"
)
print
(
"❌ Implementations differ"
)
...
@@ -152,12 +148,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
...
@@ -152,12 +148,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
batch_size_range
=
[
2
**
i
for
i
in
range
(
0
,
7
,
2
)]
batch_size_range
=
[
2
**
i
for
i
in
range
(
0
,
7
,
2
)]
seq_length_range
=
[
2
**
i
for
i
in
range
(
6
,
11
,
1
)]
seq_length_range
=
[
2
**
i
for
i
in
range
(
6
,
11
,
1
)]
head_num_range
=
[
32
,
48
]
head_num_range
=
[
32
,
48
]
configs
=
list
(
configs
=
list
(
itertools
.
product
(
head_num_range
,
batch_size_range
,
seq_length_range
))
itertools
.
product
(
head_num_range
,
batch_size_range
,
seq_length_range
))
def
get_benchmark
(
use_residual
):
def
get_benchmark
(
use_residual
):
@
triton
.
testing
.
perf_report
(
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"head_num"
,
"batch_size"
,
"seq_len"
],
x_names
=
[
"head_num"
,
"batch_size"
,
"seq_len"
],
...
@@ -167,19 +161,15 @@ def get_benchmark(use_residual):
...
@@ -167,19 +161,15 @@ def get_benchmark(use_residual):
line_names
=
[
"HuggingFace"
,
"FlashInfer"
,
"vLLM"
],
line_names
=
[
"HuggingFace"
,
"FlashInfer"
,
"vLLM"
],
styles
=
[(
"blue"
,
"-"
),
(
"green"
,
"-"
),
(
"red"
,
"-"
)],
styles
=
[(
"blue"
,
"-"
),
(
"green"
,
"-"
),
(
"red"
,
"-"
)],
ylabel
=
"us"
,
ylabel
=
"us"
,
plot_name
=
plot_name
=
f
"rmsnorm-perf-
{
'with'
if
use_residual
else
'without'
}
-residual"
,
f
"rmsnorm-perf-
{
'with'
if
use_residual
else
'without'
}
-residual"
,
args
=
{},
args
=
{},
))
)
)
def
benchmark
(
head_num
,
batch_size
,
seq_len
,
provider
):
def
benchmark
(
head_num
,
batch_size
,
seq_len
,
provider
):
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
hidden_size
=
head_num
*
128
# assuming head_dim = 128
hidden_size
=
head_num
*
128
# assuming head_dim = 128
x
=
torch
.
randn
(
batch_size
,
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
seq_len
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
residual
=
torch
.
randn_like
(
x
)
if
use_residual
else
None
residual
=
torch
.
randn_like
(
x
)
if
use_residual
else
None
...
@@ -240,9 +230,9 @@ if __name__ == "__main__":
...
@@ -240,9 +230,9 @@ if __name__ == "__main__":
default
=
4096
,
default
=
4096
,
help
=
"Hidden size (2nd dimension) of the sequence"
,
help
=
"Hidden size (2nd dimension) of the sequence"
,
)
)
parser
.
add_argument
(
"--use-residual"
,
parser
.
add_argument
(
action
=
"store_true"
,
"--use-residual"
,
action
=
"store_true"
,
help
=
"Whether to use residual connection"
help
=
"Whether to use residual connection"
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--save-path"
,
"--save-path"
,
type
=
str
,
type
=
str
,
...
@@ -253,10 +243,12 @@ if __name__ == "__main__":
...
@@ -253,10 +243,12 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# Run correctness test
# Run correctness test
calculate_diff
(
batch_size
=
args
.
batch_size
,
calculate_diff
(
seq_len
=
args
.
seq_len
,
batch_size
=
args
.
batch_size
,
hidden_size
=
args
.
hidden_size
,
seq_len
=
args
.
seq_len
,
use_residual
=
args
.
use_residual
)
hidden_size
=
args
.
hidden_size
,
use_residual
=
args
.
use_residual
,
)
# Get the benchmark function with proper use_residual setting
# Get the benchmark function with proper use_residual setting
benchmark
=
get_benchmark
(
args
.
use_residual
)
benchmark
=
get_benchmark
(
args
.
use_residual
)
...
...
benchmarks/kernels/benchmark_rope.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
itertools
import
accumulate
from
itertools
import
accumulate
from
typing
import
Optional
from
typing
import
Optional
...
@@ -6,8 +7,7 @@ from typing import Optional
...
@@ -6,8 +7,7 @@ from typing import Optional
import
nvtx
import
nvtx
import
torch
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
(
RotaryEmbedding
,
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
,
get_rope
get_rope
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
...
@@ -23,7 +23,7 @@ def benchmark_rope_kernels_multi_lora(
...
@@ -23,7 +23,7 @@ def benchmark_rope_kernels_multi_lora(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
in
t
=
10000
,
base
:
floa
t
=
10000
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -32,40 +32,49 @@ def benchmark_rope_kernels_multi_lora(
...
@@ -32,40 +32,49 @@ def benchmark_rope_kernels_multi_lora(
# silulating serving 4 LoRAs
# silulating serving 4 LoRAs
scaling_factors
=
[
1
,
2
,
4
,
8
]
scaling_factors
=
[
1
,
2
,
4
,
8
]
# batched RoPE can take multiple scaling factors
# batched RoPE can take multiple scaling factors
batched_rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
batched_rope
=
get_rope
(
is_neox_style
,
{
head_size
,
"rope_type"
:
"linear"
,
rotary_dim
,
"factor"
:
tuple
(
scaling_factors
)
max_position
,
})
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)},
)
# non-batched RoPE takes only one scaling factor, we create multiple
# non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior
# instances to simulate the same behavior
non_batched_ropes
:
list
[
RotaryEmbedding
]
=
[]
non_batched_ropes
:
list
[
RotaryEmbedding
]
=
[]
for
scaling_factor
in
scaling_factors
:
for
scaling_factor
in
scaling_factors
:
non_batched_ropes
.
append
(
non_batched_ropes
.
append
(
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
get_rope
(
{
head_size
,
"rope_type"
:
"linear"
,
rotary_dim
,
"factor"
:
(
scaling_factor
,
)
max_position
,
}))
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
(
scaling_factor
,)},
)
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
# create query offsets for batched RoPE, we concat multiple kv cache
# create query offsets for batched RoPE, we concat multiple kv cache
# together and each query needs to find the right kv cache of its type
# together and each query needs to find the right kv cache of its type
offset_map
=
torch
.
tensor
(
offset_map
=
torch
.
tensor
(
list
(
list
(
accumulate
([
0
]
+
[
accumulate
(
max_position
*
scaling_factor
*
2
[
0
]
for
scaling_factor
in
scaling_factors
[:
-
1
]
+
[
])))
max_position
*
scaling_factor
*
2
query_types
=
torch
.
randint
(
0
,
for
scaling_factor
in
scaling_factors
[:
-
1
]
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
]
device
=
device
)
)
)
)
query_types
=
torch
.
randint
(
0
,
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
device
=
device
)
# map query types to offsets
# map query types to offsets
query_offsets
=
offset_map
[
query_types
]
query_offsets
=
offset_map
[
query_types
]
# the kernel takes flattened offsets
# the kernel takes flattened offsets
...
@@ -86,27 +95,28 @@ def benchmark_rope_kernels_multi_lora(
...
@@ -86,27 +95,28 @@ def benchmark_rope_kernels_multi_lora(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the rotary embedding kernels."
)
description
=
"Benchmark the rotary embedding kernels."
)
parser
.
add_argument
(
"--is-neox-style"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--is-neox-style"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--seq-len"
,
type
=
int
,
default
=
512
)
parser
.
add_argument
(
"--seq-len"
,
type
=
int
,
default
=
512
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
parser
.
add_argument
(
type
=
int
,
"--head-size"
,
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
type
=
int
,
default
=
128
)
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
default
=
128
,
)
parser
.
add_argument
(
"--rotary-dim"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
32
)
parser
.
add_argument
(
"--rotary-dim"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
32
)
parser
.
add_argument
(
"--dtype"
,
parser
.
add_argument
(
type
=
str
,
"--dtype"
,
type
=
str
,
choices
=
[
"bfloat16"
,
"float"
],
default
=
"float"
choices
=
[
"bfloat16"
,
"float"
],
)
default
=
"float"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--device"
,
parser
.
add_argument
(
type
=
str
,
"--device"
,
type
=
str
,
choices
=
[
"cuda:0"
,
"cuda:1"
],
default
=
"cuda:0"
choices
=
[
"cuda:0"
,
"cuda:1"
],
)
default
=
"cuda:0"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
print
(
args
)
...
...
benchmarks/kernels/benchmark_shapes.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
WEIGHT_SHAPES
=
{
WEIGHT_SHAPES
=
{
"ideal"
:
[[
4
*
256
*
32
,
256
*
32
]],
"ideal"
:
[[
4
*
256
*
32
,
256
*
32
]],
...
...
benchmarks/kernels/benchmark_w8a8_block_fp8.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from sglang quantization/tuning_block_wise_kernel.py
# Adapted from sglang quantization/tuning_block_wise_kernel.py
import
argparse
import
argparse
...
@@ -14,14 +15,16 @@ import tqdm
...
@@ -14,14 +15,16 @@ import tqdm
import
triton
import
triton
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
_w8a8_block_fp8_matmul
)
_w8a8_block_fp8_matmul
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
assert
current_platform
.
is_cuda
(
assert
current_platform
.
is_cuda
(),
(
),
"Only support tune w8a8 block fp8 kernel on CUDA device."
"Only support tune w8a8 block fp8 kernel on CUDA device."
)
DTYPE_MAP
=
{
DTYPE_MAP
=
{
"float32"
:
torch
.
float32
,
"float32"
:
torch
.
float32
,
...
@@ -40,7 +43,7 @@ def w8a8_block_matmul(
...
@@ -40,7 +43,7 @@ def w8a8_block_matmul(
config
:
dict
[
str
,
Any
],
config
:
dict
[
str
,
Any
],
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""This function performs matrix multiplication with
"""This function performs matrix multiplication with
block-wise quantization.
block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
...
@@ -51,7 +54,7 @@ def w8a8_block_matmul(
...
@@ -51,7 +54,7 @@ def w8a8_block_matmul(
B: The input tensor, e.g., weight.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization.
block_size: The block size for per-block quantization.
It should be 2-dim, e.g., [128, 128].
It should be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
output_dytpe: The dtype of the returned tensor.
...
@@ -71,18 +74,18 @@ def w8a8_block_matmul(
...
@@ -71,18 +74,18 @@ def w8a8_block_matmul(
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,
)
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
def
grid
(
META
):
def
grid
(
META
):
return
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_SIZE_M"
])
*
return
(
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
)
triton
.
cdiv
(
M
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
)
if
A
.
dtype
==
torch
.
float8_e4m3fn
:
if
A
.
dtype
==
torch
.
float8_e4m3fn
:
kernel
=
_w8a8_block_fp8_matmul
kernel
=
_w8a8_block_fp8_matmul
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
"Currently, only support tune w8a8 block fp8 kernel."
)
"Currently, only support tune w8a8 block fp8 kernel."
)
kernel
[
grid
](
kernel
[
grid
](
A
,
A
,
...
@@ -119,14 +122,16 @@ def get_configs_compute_bound():
...
@@ -119,14 +122,16 @@ def get_configs_compute_bound():
for
block_n
in
[
32
,
64
,
128
,
256
]:
for
block_n
in
[
32
,
64
,
128
,
256
]:
for
num_warps
in
[
4
,
8
]:
for
num_warps
in
[
4
,
8
]:
for
group_size
in
[
1
,
16
,
32
,
64
]:
for
group_size
in
[
1
,
16
,
32
,
64
]:
configs
.
append
({
configs
.
append
(
"BLOCK_SIZE_M"
:
block_m
,
{
"BLOCK_SIZE_N"
:
block_n
,
"BLOCK_SIZE_M"
:
block_m
,
"BLOCK_SIZE_K"
:
block_k
,
"BLOCK_SIZE_N"
:
block_n
,
"GROUP_SIZE_M"
:
group_size
,
"BLOCK_SIZE_K"
:
block_k
,
"num_warps"
:
num_warps
,
"GROUP_SIZE_M"
:
group_size
,
"num_stages"
:
num_stages
,
"num_warps"
:
num_warps
,
})
"num_stages"
:
num_stages
,
}
)
return
configs
return
configs
...
@@ -165,15 +170,9 @@ def get_weight_shapes(tp_size):
...
@@ -165,15 +170,9 @@ def get_weight_shapes(tp_size):
return
weight_shapes
return
weight_shapes
def
benchmark_config
(
A
,
def
benchmark_config
(
B
,
A
,
B
,
As
,
Bs
,
block_size
,
config
,
out_dtype
=
torch
.
float16
,
num_iters
=
10
As
,
):
Bs
,
block_size
,
config
,
out_dtype
=
torch
.
float16
,
num_iters
=
10
):
def
run
():
def
run
():
w8a8_block_matmul
(
A
,
B
,
As
,
Bs
,
block_size
,
config
,
out_dtype
)
w8a8_block_matmul
(
A
,
B
,
As
,
Bs
,
block_size
,
config
,
out_dtype
)
...
@@ -206,26 +205,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
...
@@ -206,26 +205,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
A_fp32
=
(
A_fp32
=
(
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
-
0.5
)
*
2
*
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
-
0.5
)
*
2
*
fp8_max
fp8_max
)
)
A
=
A_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
A
=
A_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
B_fp32
=
(
B_fp32
=
(
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
-
0.5
)
*
2
*
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
-
0.5
)
*
2
*
fp8_max
fp8_max
)
)
B
=
B_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
B
=
B_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
"Currently, only support tune w8a8 block fp8 kernel."
)
"Currently, only support tune w8a8 block fp8 kernel."
)
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
As
=
torch
.
rand
(
M
,
k_tiles
,
dtype
=
torch
.
float32
,
As
=
torch
.
rand
(
M
,
k_tiles
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
factor_for_scale
device
=
"cuda"
)
*
factor_for_scale
Bs
=
(
Bs
=
(
torch
.
rand
(
n_tiles
,
k_tiles
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
torch
.
rand
(
n_tiles
,
k_tiles
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
factor_for_scale
)
*
factor_for_scale
)
best_config
=
None
best_config
=
None
best_time
=
float
(
"inf"
)
best_time
=
float
(
"inf"
)
...
@@ -267,7 +266,8 @@ def save_configs(
...
@@ -267,7 +266,8 @@ def save_configs(
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
json_file_name
=
(
json_file_name
=
(
f
"N=
{
N
}
,K=
{
K
}
,device_name=
{
device_name
}
,dtype=
{
input_type
}
_w8a8,"
f
"N=
{
N
}
,K=
{
K
}
,device_name=
{
device_name
}
,dtype=
{
input_type
}
_w8a8,"
f
"block_shape=[
{
block_n
}
,
{
block_k
}
].json"
)
f
"block_shape=[
{
block_n
}
,
{
block_k
}
].json"
)
config_file_path
=
os
.
path
.
join
(
save_path
,
json_file_name
)
config_file_path
=
os
.
path
.
join
(
save_path
,
json_file_name
)
print
(
f
"Writing best config to
{
config_file_path
}
..."
)
print
(
f
"Writing best config to
{
config_file_path
}
..."
)
...
@@ -295,8 +295,7 @@ def tune_on_gpu(args_dict):
...
@@ -295,8 +295,7 @@ def tune_on_gpu(args_dict):
search_space
=
get_configs_compute_bound
()
search_space
=
get_configs_compute_bound
()
search_space
=
[
search_space
=
[
config
for
config
in
search_space
config
for
config
in
search_space
if
block_k
%
config
[
"BLOCK_SIZE_K"
]
==
0
if
block_k
%
config
[
"BLOCK_SIZE_K"
]
==
0
]
]
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -312,15 +311,11 @@ def tune_on_gpu(args_dict):
...
@@ -312,15 +311,11 @@ def tune_on_gpu(args_dict):
out_dtype
,
out_dtype
,
search_space
,
search_space
,
input_type
,
input_type
,
)
for
batch_size
in
tqdm
(
batch_sizes
,
)
desc
=
f
"GPU
{
gpu_id
}
- Batch sizes"
)
for
batch_size
in
tqdm
(
batch_sizes
,
desc
=
f
"GPU
{
gpu_id
}
- Batch sizes"
)
]
]
best_configs
=
{
best_configs
=
{
M
:
config
for
M
,
config
in
zip
(
batch_sizes
,
benchmark_results
)}
M
:
config
save_configs
(
N
,
K
,
block_n
,
block_k
,
best_configs
,
save_path
,
input_type
)
for
M
,
config
in
zip
(
batch_sizes
,
benchmark_results
)
}
save_configs
(
N
,
K
,
block_n
,
block_k
,
best_configs
,
save_path
,
input_type
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
f
"Tuning on GPU
{
gpu_id
}
took
{
end
-
start
:.
2
f
}
seconds"
)
print
(
f
"Tuning on GPU
{
gpu_id
}
took
{
end
-
start
:.
2
f
}
seconds"
)
...
@@ -376,13 +371,14 @@ def main(args):
...
@@ -376,13 +371,14 @@ def main(args):
process_args
=
[]
process_args
=
[]
for
gpu_id
in
range
(
num_gpus
):
for
gpu_id
in
range
(
num_gpus
):
process_args
.
append
({
process_args
.
append
(
"gpu_id"
:
gpu_id
,
{
"batch_sizes"
:
batches_per_gpu
[
gpu_id
],
"gpu_id"
:
gpu_id
,
"weight_shapes"
:
"batch_sizes"
:
batches_per_gpu
[
gpu_id
],
weight_shapes
,
# Each GPU processes all weight shapes
"weight_shapes"
:
weight_shapes
,
# Each GPU processes all weight shapes
"args"
:
args
,
"args"
:
args
,
})
}
)
ctx
=
mp
.
get_context
(
"spawn"
)
ctx
=
mp
.
get_context
(
"spawn"
)
with
ctx
.
Pool
(
num_gpus
)
as
pool
:
with
ctx
.
Pool
(
num_gpus
)
as
pool
:
...
@@ -398,13 +394,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
...
@@ -398,13 +394,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs
Then copy to model_executor/layers/quantization/utils/configs
"""
,
"""
,
formatter_class
=
argparse
.
RawTextHelpFormatter
)
formatter_class
=
argparse
.
RawTextHelpFormatter
,
)
parser
.
add_argument
(
"--tp-size"
,
"-tp"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--tp-size"
,
"-tp"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--input-type"
,
parser
.
add_argument
(
"--input-type"
,
type
=
str
,
choices
=
[
"fp8"
],
default
=
"fp8"
)
type
=
str
,
choices
=
[
"fp8"
],
default
=
"fp8"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--out-dtype"
,
"--out-dtype"
,
type
=
str
,
type
=
str
,
...
...
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# fmt: off
# fmt: off
# ruff: noqa: E501
# ruff: noqa: E501
import
time
import
time
...
@@ -6,13 +7,15 @@ import time
...
@@ -6,13 +7,15 @@ import time
# Import DeepGEMM functions
# Import DeepGEMM functions
import
deep_gemm
import
deep_gemm
import
torch
import
torch
import
triton
from
deep_gemm
import
calc_diff
,
ceil_div
,
get_col_major_tma_aligned_tensor
from
deep_gemm
import
calc_diff
,
ceil_div
,
get_col_major_tma_aligned_tensor
# Import vLLM functions
# Import vLLM functions
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
)
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
,
)
from
vllm.triton_utils
import
triton
# Copied from
# Copied from
...
...
benchmarks/kernels/graph_machete_bench.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
math
import
pickle
import
pickle
import
re
from
collections
import
defaultdict
from
collections
import
defaultdict
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
pandas
as
pd
import
pandas
as
pd
import
regex
as
re
import
seaborn
as
sns
import
seaborn
as
sns
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
...
@@ -14,13 +15,14 @@ from vllm.utils import FlexibleArgumentParser
...
@@ -14,13 +15,14 @@ from vllm.utils import FlexibleArgumentParser
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
'Benchmark the latency of processing a single batch of '
description
=
"Benchmark the latency of processing a single batch of "
'requests till completion.'
)
"requests till completion."
parser
.
add_argument
(
'filename'
,
type
=
str
)
)
parser
.
add_argument
(
"filename"
,
type
=
str
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
with
open
(
args
.
filename
,
'
rb
'
)
as
f
:
with
open
(
args
.
filename
,
"
rb
"
)
as
f
:
data
=
pickle
.
load
(
f
)
data
=
pickle
.
load
(
f
)
raw_results
:
list
[
TMeasurement
]
=
data
[
"results"
]
raw_results
:
list
[
TMeasurement
]
=
data
[
"results"
]
...
@@ -38,11 +40,7 @@ if __name__ == "__main__":
...
@@ -38,11 +40,7 @@ if __name__ == "__main__":
raise
Exception
(
"MKN not found"
)
raise
Exception
(
"MKN not found"
)
kernel
=
v
.
task_spec
.
description
kernel
=
v
.
task_spec
.
description
results
[
KN
].
append
({
results
[
KN
].
append
({
"kernel"
:
kernel
,
"batch_size"
:
M
,
"median"
:
v
.
median
})
"kernel"
:
kernel
,
"batch_size"
:
M
,
"median"
:
v
.
median
})
rows
=
int
(
math
.
ceil
(
len
(
results
)
/
2
))
rows
=
int
(
math
.
ceil
(
len
(
results
)
/
2
))
fig
,
axs
=
plt
.
subplots
(
rows
,
2
,
figsize
=
(
12
,
5
*
rows
))
fig
,
axs
=
plt
.
subplots
(
rows
,
2
,
figsize
=
(
12
,
5
*
rows
))
...
@@ -50,14 +48,16 @@ if __name__ == "__main__":
...
@@ -50,14 +48,16 @@ if __name__ == "__main__":
for
axs_idx
,
(
shape
,
data
)
in
enumerate
(
results
.
items
()):
for
axs_idx
,
(
shape
,
data
)
in
enumerate
(
results
.
items
()):
plt
.
sca
(
axs
[
axs_idx
])
plt
.
sca
(
axs
[
axs_idx
])
df
=
pd
.
DataFrame
(
data
)
df
=
pd
.
DataFrame
(
data
)
sns
.
lineplot
(
data
=
df
,
sns
.
lineplot
(
x
=
"batch_size"
,
data
=
df
,
y
=
"median"
,
x
=
"batch_size"
,
hue
=
"kernel"
,
y
=
"median"
,
style
=
"kernel"
,
hue
=
"kernel"
,
markers
=
True
,
style
=
"kernel"
,
dashes
=
False
,
markers
=
True
,
palette
=
"Dark2"
)
dashes
=
False
,
palette
=
"Dark2"
,
)
plt
.
title
(
f
"Shape:
{
shape
}
"
)
plt
.
title
(
f
"Shape:
{
shape
}
"
)
plt
.
ylabel
(
"time (median, s)"
)
plt
.
ylabel
(
"time (median, s)"
)
plt
.
tight_layout
()
plt
.
tight_layout
()
...
...
benchmarks/kernels/utils.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
dataclasses
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
...
@@ -23,6 +24,7 @@ class ArgPool:
...
@@ -23,6 +24,7 @@ class ArgPool:
For every invocation during a benchmarking run, it will choose a
For every invocation during a benchmarking run, it will choose a
different value from the list.
different value from the list.
"""
"""
values
:
Iterable
[
Any
]
values
:
Iterable
[
Any
]
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
...
@@ -30,9 +32,7 @@ class ArgPool:
...
@@ -30,9 +32,7 @@ class ArgPool:
class
Bench
:
class
Bench
:
class
ArgsIterator
:
class
ArgsIterator
:
def
__init__
(
self
,
args_list
,
kwargs_list
):
def
__init__
(
self
,
args_list
,
kwargs_list
):
assert
len
(
args_list
)
==
len
(
kwargs_list
)
assert
len
(
args_list
)
==
len
(
kwargs_list
)
self
.
args_list
=
args_list
self
.
args_list
=
args_list
...
@@ -53,10 +53,16 @@ class Bench:
...
@@ -53,10 +53,16 @@ class Bench:
def
n_args
(
self
):
def
n_args
(
self
):
return
self
.
n
return
self
.
n
def
__init__
(
self
,
cuda_graph_params
:
Optional
[
CudaGraphBenchParams
],
def
__init__
(
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
self
,
*
args
,
**
kwargs
):
cuda_graph_params
:
Optional
[
CudaGraphBenchParams
],
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
*
args
,
**
kwargs
,
):
self
.
cuda_graph_params
=
cuda_graph_params
self
.
cuda_graph_params
=
cuda_graph_params
self
.
use_cuda_graph
=
self
.
cuda_graph_params
is
not
None
self
.
use_cuda_graph
=
self
.
cuda_graph_params
is
not
None
self
.
label
=
label
self
.
label
=
label
...
@@ -67,10 +73,8 @@ class Bench:
...
@@ -67,10 +73,8 @@ class Bench:
# Process args
# Process args
self
.
_args
=
args
self
.
_args
=
args
self
.
_kwargs
=
kwargs
self
.
_kwargs
=
kwargs
self
.
args_list
,
self
.
kwargs_list
=
self
.
collapse_argpool
(
self
.
args_list
,
self
.
kwargs_list
=
self
.
collapse_argpool
(
*
args
,
**
kwargs
)
*
args
,
**
kwargs
)
self
.
args_iterator
=
self
.
ArgsIterator
(
self
.
args_list
,
self
.
kwargs_list
)
self
.
args_iterator
=
self
.
ArgsIterator
(
self
.
args_list
,
self
.
kwargs_list
)
# Cudagraph runner
# Cudagraph runner
self
.
g
=
None
self
.
g
=
None
...
@@ -100,16 +104,13 @@ class Bench:
...
@@ -100,16 +104,13 @@ class Bench:
for
i
in
range
(
argpool_size
):
for
i
in
range
(
argpool_size
):
# collapse args; Just pick the ith value
# collapse args; Just pick the ith value
args_list
[
i
]
=
tuple
([
args_list
[
i
]
=
tuple
(
arg
[
i
]
if
isinstance
(
arg
,
ArgPool
)
else
arg
[
arg
[
i
]
if
isinstance
(
arg
,
ArgPool
)
else
arg
for
arg
in
args_list
[
i
]]
for
arg
in
args_list
[
i
]
)
])
# collapse kwargs
# collapse kwargs
kwargs_i
=
kwargs_list
[
i
]
kwargs_i
=
kwargs_list
[
i
]
arg_pool_keys
=
[
arg_pool_keys
=
[
k
for
k
,
v
in
kwargs_i
.
items
()
if
isinstance
(
v
,
ArgPool
)]
k
for
k
,
v
in
kwargs_i
.
items
()
if
isinstance
(
v
,
ArgPool
)
]
for
k
in
arg_pool_keys
:
for
k
in
arg_pool_keys
:
# again just pick the ith value
# again just pick the ith value
kwargs_i
[
k
]
=
kwargs_i
[
k
][
i
]
kwargs_i
[
k
]
=
kwargs_i
[
k
][
i
]
...
@@ -142,7 +143,7 @@ class Bench:
...
@@ -142,7 +143,7 @@ class Bench:
def
run_cudagrah
(
self
)
->
TMeasurement
:
def
run_cudagrah
(
self
)
->
TMeasurement
:
assert
self
.
use_cuda_graph
assert
self
.
use_cuda_graph
globals
=
{
'g'
:
self
.
g
}
globals
=
{
"g"
:
self
.
g
}
return
TBenchmark
.
Timer
(
return
TBenchmark
.
Timer
(
stmt
=
"g.replay()"
,
stmt
=
"g.replay()"
,
...
@@ -162,15 +163,15 @@ class Bench:
...
@@ -162,15 +163,15 @@ class Bench:
has_arg_pool
=
self
.
args_iterator
.
n_args
>
1
has_arg_pool
=
self
.
args_iterator
.
n_args
>
1
if
has_arg_pool
:
if
has_arg_pool
:
setup
=
'''
setup
=
"""
args_iterator.reset()
args_iterator.reset()
args_it = args_iterator.__next__()
args_it = args_iterator.__next__()
'''
"""
stmt
=
'''
stmt
=
"""
args, kwargs = next(args_it)
args, kwargs = next(args_it)
fn(*args, **kwargs)
fn(*args, **kwargs)
'''
"""
globals
=
{
'
fn
'
:
self
.
fn
,
'
args_iterator
'
:
self
.
args_iterator
}
globals
=
{
"
fn
"
:
self
.
fn
,
"
args_iterator
"
:
self
.
args_iterator
}
else
:
else
:
# no arg pool. Just use the args and kwargs directly
# no arg pool. Just use the args and kwargs directly
self
.
args_iterator
.
reset
()
self
.
args_iterator
.
reset
()
...
@@ -178,10 +179,10 @@ class Bench:
...
@@ -178,10 +179,10 @@ class Bench:
args
,
kwargs
=
next
(
args_it
)
args
,
kwargs
=
next
(
args_it
)
setup
=
""
setup
=
""
stmt
=
'''
stmt
=
"""
fn(*args, **kwargs)
fn(*args, **kwargs)
'''
"""
globals
=
{
'
fn
'
:
self
.
fn
,
'
args
'
:
args
,
'
kwargs
'
:
kwargs
}
globals
=
{
"
fn
"
:
self
.
fn
,
"
args
"
:
args
,
"
kwargs
"
:
kwargs
}
return
TBenchmark
.
Timer
(
return
TBenchmark
.
Timer
(
stmt
=
stmt
,
stmt
=
stmt
,
...
...
benchmarks/kernels/weight_shapes.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Weight Shapes are in the format
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# ([K, N], TP_SPLIT_DIM)
...
@@ -48,4 +49,50 @@ WEIGHT_SHAPES = {
...
@@ -48,4 +49,50 @@ WEIGHT_SHAPES = {
([
16384
,
106496
],
1
),
([
16384
,
106496
],
1
),
([
53248
,
16384
],
0
),
([
53248
,
16384
],
0
),
],
],
"meta-llama/Llama-3.1-8B-Instruct"
:
[
([
4096
,
6144
],
1
),
([
4096
,
4096
],
0
),
([
4096
,
28672
],
1
),
([
14336
,
4096
],
0
),
],
"meta-llama/Llama-3.3-70B-Instruct"
:
[
([
8192
,
10240
],
1
),
([
8192
,
8192
],
0
),
([
8192
,
57344
],
1
),
([
28672
,
8192
],
0
),
],
"mistralai/Mistral-Large-Instruct-2407"
:
[
([
12288
,
14336
],
1
),
([
12288
,
12288
],
0
),
([
12288
,
57344
],
1
),
([
28672
,
12288
],
0
),
],
"Qwen/Qwen2.5-7B-Instruct"
:
[
([
3584
,
4608
],
1
),
([
3584
,
3584
],
0
),
([
3584
,
37888
],
1
),
([
18944
,
3584
],
0
),
],
"Qwen/Qwen2.5-32B-Instruct"
:
[
([
5120
,
7168
],
1
),
([
5120
,
5120
],
0
),
([
5120
,
55296
],
1
),
([
27648
,
5120
],
0
),
],
"Qwen/Qwen2.5-72B-Instruct"
:
[
([
8192
,
10240
],
1
),
([
8192
,
8192
],
0
),
([
8192
,
59136
],
1
),
([
29568
,
8192
],
0
),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
:
[
([
2048
,
3072
],
1
),
([
2048
,
4096
],
1
),
([
2048
,
2048
],
0
),
([
2048
,
576
],
0
),
([
2048
,
21888
],
1
),
([
10944
,
2048
],
0
),
([
2048
,
2816
],
1
),
([
1408
,
2048
],
0
),
],
}
}
benchmarks/overheads/benchmark_hashing.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
cProfile
import
cProfile
import
pstats
import
pstats
...
@@ -7,9 +8,8 @@ from vllm import LLM, SamplingParams
...
@@ -7,9 +8,8 @@ from vllm import LLM, SamplingParams
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k.
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT
=
[
"You are an expert in large language models, aren't you?"
LONG_PROMPT
=
[
"You are an expert in large language models, aren't you?"
]
*
1000
]
*
1000
LONG_PROMPT
=
" "
.
join
(
LONG_PROMPT
)
LONG_PROMPT
=
' '
.
join
(
LONG_PROMPT
)
def
main
(
args
):
def
main
(
args
):
...
@@ -30,32 +30,35 @@ def main(args):
...
@@ -30,32 +30,35 @@ def main(args):
print
(
"------start generating------"
)
print
(
"------start generating------"
)
for
i
in
range
(
3
):
for
i
in
range
(
3
):
profiler
.
runctx
(
'llm.generate(LONG_PROMPT, sampling_params)'
,
profiler
.
runctx
(
globals
(),
locals
())
"llm.generate(LONG_PROMPT, sampling_params)"
,
globals
(),
locals
()
)
# analyze the runtime of hashing function
# analyze the runtime of hashing function
stats
=
pstats
.
Stats
(
profiler
)
stats
=
pstats
.
Stats
(
profiler
)
stats
.
sort_stats
(
'
cumulative
'
)
stats
.
sort_stats
(
"
cumulative
"
)
total_time
=
0
total_time
=
0
total_calls
=
0
total_calls
=
0
for
func
in
stats
.
stats
:
for
func
in
stats
.
stats
:
if
'
hash_of_block
'
in
func
[
2
]:
if
"
hash_of_block
"
in
func
[
2
]:
total_time
=
stats
.
stats
[
func
][
3
]
total_time
=
stats
.
stats
[
func
][
3
]
total_calls
=
stats
.
stats
[
func
][
0
]
total_calls
=
stats
.
stats
[
func
][
0
]
percentage
=
(
total_time
/
stats
.
total_tt
)
*
100
percentage
=
(
total_time
/
stats
.
total_tt
)
*
100
print
(
f
"Hashing took
{
total_time
:.
2
f
}
seconds,"
print
(
f
"
{
percentage
:.
2
f
}
% of the total runtime."
)
f
"Hashing took
{
total_time
:.
2
f
}
seconds,
{
percentage
:.
2
f
}
% of the total runtime."
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
'Benchmark the performance of hashing function in'
description
=
"Benchmark the performance of hashing function in"
'automatic prefix caching.'
)
"automatic prefix caching."
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'lmsys/longchat-7b-16k'
)
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"lmsys/longchat-7b-16k"
)
parser
.
add_argument
(
'--output-len'
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
'--enable-prefix-caching'
,
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
10
)
action
=
'store_true'
,
parser
.
add_argument
(
help
=
'enable prefix caching'
)
"--enable-prefix-caching"
,
action
=
"store_true"
,
help
=
"enable prefix caching"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
)
main
(
args
)
benchmarks/pyproject.toml
0 → 100644
View file @
4c676e3d
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length
=
88
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**"
=
["ALL"]
"vllm/version.py"
=
["F401"]
"vllm/_version.py"
=
["ALL"]
[tool.ruff.lint]
select
=
[
# pycodestyle
"E"
,
# Pyflakes
"F"
,
# pyupgrade
"UP"
,
# flake8-bugbear
"B"
,
# flake8-simplify
"SIM"
,
# isort
"I"
,
# flake8-logging-format
"G"
,
]
ignore
=
[
# star imports
"F405"
,
"F403"
,
# lambda expression assignment
"E731"
,
# Loop control variable not used within loop body
"B007"
,
# f-string format
"UP032"
,
# Can remove once 3.10+ is the minimum Python version
"UP007"
,
]
[tool.ruff.lint.isort]
known-first-party
=
["vllm"]
[tool.ruff.format]
docstring-code-format
=
true
\ No newline at end of file
benchmarks/run_structured_output_benchmark.sh
View file @
4c676e3d
#!/bin/bash
#!/bin/bash
# Define the model to use
# default values
MODEL
=
${
1
:-
"Qwen/Qwen2.5-7B-Instruct"
}
MODEL
=
${
MODEL
:-
"Qwen/Qwen2.5-7B-Instruct"
}
BACKEND
=
${
BACKEND
:-
"vllm"
}
# Define the backend to use
DATASET
=
${
DATASET
:-
"xgrammar_bench"
}
BACKEND
=
${
2
:-
"vllm"
}
# Define the dataset to use
DATASET
=
${
3
:-
"xgrammar_bench"
}
# Define the guided decoding backend
GUIDED_BACKEND
=
${
4
:-
"xgrammar"
}
SCRIPT_DIR
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
&&
pwd
)
"
SCRIPT_DIR
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
&&
pwd
)
"
OUTPUT_DIR
=
${
5
:-
"
$SCRIPT_DIR
/structured_output_benchmark_results"
}
OUTPUT_DIR
=
${
OUTPUT_DIR
:-
"
$SCRIPT_DIR
/structured_output_benchmark_results"
}
PORT
=
${
PORT
:-
8000
}
GUIDED_RATIO
=
${
6
:-
0
.5
}
STRUCTURED_OUTPUT_RATIO
=
${
STRUCTURED_OUTPUT_RATIO
:-
1
}
TOTAL_SECONDS
=
${
TOTAL_SECONDS
:-
90
}
MAX_NEW_TOKENS
=
${
MAX_NEW_TOKENS
:-
300
}
TOKENIZER_MODE
=
${
TOKENIZER_MODE
:-
"auto"
}
usage
()
{
echo
"Usage:
$0
[options]"
echo
"Options:"
echo
" --model MODEL Model to benchmark (default:
$MODEL
)"
echo
" --backend BACKEND Backend to use (default:
$BACKEND
)"
echo
" --dataset DATASET Dataset to use (default:
$DATASET
)"
echo
" --max-new-tokens N Maximum number of tokens to generate (default:
$MAX_NEW_TOKENS
)"
echo
" --output-dir DIR Output directory for results (default:
$OUTPUT_DIR
)"
echo
" --port PORT Port to use (default:
$PORT
)"
echo
" --structured-output-ratio N Ratio of structured outputs (default:
$STRUCTURED_OUTPUT_RATIO
)"
echo
" --tokenizer-mode MODE Tokenizer mode to use (default:
$TOKENIZER_MODE
)"
echo
" --total-seconds N Total seconds to run the benchmark (default:
$TOTAL_SECONDS
)"
echo
" -h, --help Show this help message and exit"
exit
0
}
# parse command line arguments
while
[[
$#
-gt
0
]]
;
do
case
$1
in
--model
)
MODEL
=
"
$2
"
shift
2
;;
--backend
)
BACKEND
=
"
$2
"
shift
2
;;
--dataset
)
DATASET
=
"
$2
"
shift
2
;;
--max-new-tokens
)
MAX_NEW_TOKENS
=
"
$2
"
shift
2
;;
--output-dir
)
OUTPUT_DIR
=
"
$2
"
shift
2
;;
--port
)
PORT
=
"
$2
"
shift
2
;;
--structured-output-ratio
)
STRUCTURED_OUTPUT_RATIO
=
"
$2
"
shift
2
;;
--tokenizer-mode
)
TOKENIZER_MODE
=
"
$2
"
shift
2
;;
--total-seconds
)
TOTAL_SECONDS
=
"
$2
"
shift
2
;;
-h
|
--help
)
usage
;;
*
)
echo
"Unknown argument:
$1
\n
"
usage
;;
esac
done
# Create output directory if it doesn't exist
# Create output directory if it doesn't exist
mkdir
-p
"
$OUTPUT_DIR
"
mkdir
-p
"
$OUTPUT_DIR
"
# Define QPS values to test
# Define QPS values to test
QPS_VALUES
=(
70 60 50
25 20 15 10
)
QPS_VALUES
=(
25 20 15 10
5 1
)
# Common parameters
# Common parameters
COMMON_PARAMS
=
"--backend
$BACKEND
\
COMMON_PARAMS
=
"--backend
$BACKEND
\
--model
$MODEL
\
--model
$MODEL
\
--dataset
$DATASET
\
--dataset
$DATASET
\
--structured-output-backend
$GUIDED_BACKEND
\
--structured-output-ratio
$STRUCTURED_OUTPUT_RATIO
\
--structured-output-ratio
$GUIDED_RATIO
\
--save-results
\
--save-results
\
--result-dir
$OUTPUT_DIR
"
--result-dir
$OUTPUT_DIR
\
--output-len
$MAX_NEW_TOKENS
\
--port
$PORT
\
--tokenizer-mode
$TOKENIZER_MODE
"
echo
"Starting structured output benchmark with model:
$MODEL
"
echo
"Starting structured output benchmark with model:
$MODEL
"
echo
"Backend:
$BACKEND
"
echo
"Backend:
$BACKEND
"
echo
"Dataset:
$DATASET
"
echo
"Dataset:
$DATASET
"
echo
"Structured output backend:
$GUIDED_BACKEND
"
echo
"Results will be saved to:
$OUTPUT_DIR
"
echo
"Results will be saved to:
$OUTPUT_DIR
"
echo
"----------------------------------------"
echo
"----------------------------------------"
...
@@ -48,14 +109,17 @@ for qps in "${QPS_VALUES[@]}"; do
...
@@ -48,14 +109,17 @@ for qps in "${QPS_VALUES[@]}"; do
GIT_BRANCH
=
$(
git rev-parse
--abbrev-ref
HEAD 2>/dev/null
||
echo
"unknown"
)
GIT_BRANCH
=
$(
git rev-parse
--abbrev-ref
HEAD 2>/dev/null
||
echo
"unknown"
)
# Construct filename for this run
# Construct filename for this run
FILENAME
=
"
${
GUIDED_BACKEND
}
_
${
BACKEND
}
_
${
qps
}
qps_
$(
basename
$MODEL
)
_
${
DATASET
}
_
${
GIT_HASH
}
.json"
FILENAME
=
"
${
BACKEND
}
_
${
qps
}
qps_
$(
basename
$MODEL
)
_
${
DATASET
}
_
${
GIT_HASH
}
.json"
NUM_PROMPTS
=
$(
echo
"
$TOTAL_SECONDS
*
$qps
"
| bc
)
NUM_PROMPTS
=
${
NUM_PROMPTS
%.*
}
# Remove fractional part
echo
"Running benchmark with
$NUM_PROMPTS
prompts"
# Run the benchmark
# Run the benchmark
python
"
$SCRIPT_DIR
/benchmark_serving_structured_output.py"
$COMMON_PARAMS
\
python
"
$SCRIPT_DIR
/benchmark_serving_structured_output.py"
$COMMON_PARAMS
\
--request-rate
$qps
\
--request-rate
$qps
\
--result-filename
"
$FILENAME
"
\
--result-filename
"
$FILENAME
"
\
--tokenizer-mode
${
TOKENIZER_MODE
:-
"auto"
}
\
--num-prompts
$NUM_PROMPTS
--port
${
PORT
:-
8000
}
echo
"Completed benchmark with QPS:
$qps
"
echo
"Completed benchmark with QPS:
$qps
"
echo
"----------------------------------------"
echo
"----------------------------------------"
...
...
cmake/cpu_extension.cmake
View file @
4c676e3d
...
@@ -75,6 +75,7 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
...
@@ -75,6 +75,7 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
else
()
else
()
find_isa
(
${
CPUINFO
}
"avx2"
AVX2_FOUND
)
find_isa
(
${
CPUINFO
}
"avx2"
AVX2_FOUND
)
find_isa
(
${
CPUINFO
}
"avx512f"
AVX512_FOUND
)
find_isa
(
${
CPUINFO
}
"avx512f"
AVX512_FOUND
)
find_isa
(
${
CPUINFO
}
"Power11"
POWER11_FOUND
)
find_isa
(
${
CPUINFO
}
"POWER10"
POWER10_FOUND
)
find_isa
(
${
CPUINFO
}
"POWER10"
POWER10_FOUND
)
find_isa
(
${
CPUINFO
}
"POWER9"
POWER9_FOUND
)
find_isa
(
${
CPUINFO
}
"POWER9"
POWER9_FOUND
)
find_isa
(
${
CPUINFO
}
"asimd"
ASIMD_FOUND
)
# Check for ARM NEON support
find_isa
(
${
CPUINFO
}
"asimd"
ASIMD_FOUND
)
# Check for ARM NEON support
...
@@ -106,13 +107,19 @@ elseif (AVX2_FOUND)
...
@@ -106,13 +107,19 @@ elseif (AVX2_FOUND)
list
(
APPEND CXX_COMPILE_FLAGS
"-mavx2"
)
list
(
APPEND CXX_COMPILE_FLAGS
"-mavx2"
)
message
(
WARNING
"vLLM CPU backend using AVX2 ISA"
)
message
(
WARNING
"vLLM CPU backend using AVX2 ISA"
)
elseif
(
POWER9_FOUND OR POWER10_FOUND
)
elseif
(
POWER9_FOUND OR POWER10_FOUND
OR POWER11_FOUND
)
message
(
STATUS
"PowerPC detected"
)
message
(
STATUS
"PowerPC detected"
)
# Check for PowerPC VSX support
if
(
POWER9_FOUND
)
list
(
APPEND CXX_COMPILE_FLAGS
list
(
APPEND CXX_COMPILE_FLAGS
"-mvsx"
"-mvsx"
"-mcpu=native"
"-mcpu=power9"
"-mtune=native"
)
"-mtune=power9"
)
elseif
(
POWER10_FOUND OR POWER11_FOUND
)
list
(
APPEND CXX_COMPILE_FLAGS
"-mvsx"
"-mcpu=power10"
"-mtune=power10"
)
endif
()
elseif
(
ASIMD_FOUND
)
elseif
(
ASIMD_FOUND
)
message
(
STATUS
"ARMv8 or later architecture detected"
)
message
(
STATUS
"ARMv8 or later architecture detected"
)
...
@@ -167,6 +174,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
...
@@ -167,6 +174,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
FetchContent_MakeAvailable
(
oneDNN
)
FetchContent_MakeAvailable
(
oneDNN
)
list
(
APPEND LIBS dnnl
)
elseif
(
POWER10_FOUND
)
FetchContent_Declare
(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.7.2
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
set
(
ONEDNN_LIBRARY_TYPE
"STATIC"
)
set
(
ONEDNN_BUILD_DOC
"OFF"
)
set
(
ONEDNN_BUILD_EXAMPLES
"OFF"
)
set
(
ONEDNN_BUILD_TESTS
"OFF"
)
set
(
ONEDNN_ENABLE_WORKLOAD
"INFERENCE"
)
set
(
ONEDNN_ENABLE_PRIMITIVE
"MATMUL;REORDER"
)
set
(
ONEDNN_BUILD_GRAPH
"OFF"
)
set
(
ONEDNN_ENABLE_JIT_PROFILING
"OFF"
)
set
(
ONEDNN_ENABLE_ITT_TASKS
"OFF"
)
set
(
ONEDNN_ENABLE_MAX_CPU_ISA
"OFF"
)
set
(
ONEDNN_ENABLE_CPU_ISA_HINTS
"OFF"
)
set
(
CMAKE_POLICY_DEFAULT_CMP0077 NEW
)
set
(
DNNL_CPU_RUNTIME
"OMP"
)
FetchContent_MakeAvailable
(
oneDNN
)
list
(
APPEND LIBS dnnl
)
list
(
APPEND LIBS dnnl
)
endif
()
endif
()
...
@@ -197,6 +231,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
...
@@ -197,6 +231,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
"csrc/cpu/quant.cpp"
"csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp"
"csrc/cpu/shm.cpp"
${
VLLM_EXT_SRC
}
)
${
VLLM_EXT_SRC
}
)
elseif
(
POWER10_FOUND
)
set
(
VLLM_EXT_SRC
"csrc/cpu/quant.cpp"
${
VLLM_EXT_SRC
}
)
endif
()
endif
()
#
#
...
@@ -214,4 +252,4 @@ define_gpu_extension_target(
...
@@ -214,4 +252,4 @@ define_gpu_extension_target(
WITH_SOABI
WITH_SOABI
)
)
message
(
STATUS
"Enabling C extension."
)
message
(
STATUS
"Enabling C extension."
)
\ No newline at end of file
cmake/external_projects/vllm_flash_attn.cmake
View file @
4c676e3d
...
@@ -46,22 +46,38 @@ else()
...
@@ -46,22 +46,38 @@ else()
endif
()
endif
()
# Ensure the vllm/vllm_flash_attn directory exists before installation
install
(
CODE
"file(MAKE_DIRECTORY
\"\$
{CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn
\"
)"
ALL_COMPONENTS
)
# Make sure vllm-flash-attn install rules are nested under vllm/
# This is here to support installing all components under the same prefix with cmake --install.
# setup.py installs every component separately but uses the same prefix for all.
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3,
# and these statements don't hurt when installing neither component.
install
(
CODE
"set(CMAKE_INSTALL_LOCAL_ONLY FALSE)"
ALL_COMPONENTS
)
install
(
CODE
"set(OLD_CMAKE_INSTALL_PREFIX
\"\$
{CMAKE_INSTALL_PREFIX}
\"
)"
ALL_COMPONENTS
)
install
(
CODE
"set(CMAKE_INSTALL_PREFIX
\"\$
{CMAKE_INSTALL_PREFIX}/vllm/
\"
)"
ALL_COMPONENTS
)
# Fetch the vllm-flash-attn library
# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable
(
vllm-flash-attn
)
FetchContent_MakeAvailable
(
vllm-flash-attn
)
message
(
STATUS
"vllm-flash-attn is available at
${
vllm-flash-attn_SOURCE_DIR
}
"
)
message
(
STATUS
"vllm-flash-attn is available at
${
vllm-flash-attn_SOURCE_DIR
}
"
)
# Restore the install prefix
install
(
CODE
"set(CMAKE_INSTALL_PREFIX
\"\$
{OLD_CMAKE_INSTALL_PREFIX}
\"
)"
ALL_COMPONENTS
)
install
(
CODE
"set(CMAKE_INSTALL_LOCAL_ONLY TRUE)"
ALL_COMPONENTS
)
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
# case only one is built, in the case both are built redundant work is done)
# case only one is built, in the case both are built redundant work is done)
install
(
install
(
DIRECTORY
${
vllm-flash-attn_SOURCE_DIR
}
/vllm_flash_attn/
DIRECTORY
${
vllm-flash-attn_SOURCE_DIR
}
/vllm_flash_attn/
DESTINATION vllm_flash_attn
DESTINATION
vllm/
vllm_flash_attn
COMPONENT _vllm_fa2_C
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN
"*.py"
FILES_MATCHING PATTERN
"*.py"
)
)
install
(
install
(
DIRECTORY
${
vllm-flash-attn_SOURCE_DIR
}
/vllm_flash_attn/
DIRECTORY
${
vllm-flash-attn_SOURCE_DIR
}
/vllm_flash_attn/
DESTINATION vllm_flash_attn
DESTINATION
vllm/
vllm_flash_attn
COMPONENT _vllm_fa3_C
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN
"*.py"
FILES_MATCHING PATTERN
"*.py"
)
)
cmake/hipify.py
View file @
4c676e3d
#!/usr/bin/env python3
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
#
# A command line tool for running pytorch's hipify preprocessor on CUDA
# A command line tool for running pytorch's hipify preprocessor on CUDA
...
...
cmake/utils.cmake
View file @
4c676e3d
...
@@ -76,7 +76,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
...
@@ -76,7 +76,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
set
(
CSRC_BUILD_DIR
${
CMAKE_CURRENT_BINARY_DIR
}
/csrc
)
set
(
CSRC_BUILD_DIR
${
CMAKE_CURRENT_BINARY_DIR
}
/csrc
)
add_custom_target
(
add_custom_target
(
hipify
${
NAME
}
hipify
${
NAME
}
COMMAND
${
CMAKE_SOURCE_DIR
}
/cmake/hipify.py -p
${
CMAKE_SOURCE_DIR
}
/csrc -o
${
CSRC_BUILD_DIR
}
${
SRCS
}
COMMAND
${
Python_EXECUTABLE
}
${
CMAKE_SOURCE_DIR
}
/cmake/hipify.py -p
${
CMAKE_SOURCE_DIR
}
/csrc -o
${
CSRC_BUILD_DIR
}
${
SRCS
}
DEPENDS
${
CMAKE_SOURCE_DIR
}
/cmake/hipify.py
${
SRCS
}
DEPENDS
${
CMAKE_SOURCE_DIR
}
/cmake/hipify.py
${
SRCS
}
BYPRODUCTS
${
HIP_SRCS
}
BYPRODUCTS
${
HIP_SRCS
}
COMMENT
"Running hipify on
${
NAME
}
extension source files."
)
COMMENT
"Running hipify on
${
NAME
}
extension source files."
)
...
@@ -233,11 +233,26 @@ macro(set_gencode_flags_for_srcs)
...
@@ -233,11 +233,26 @@ macro(set_gencode_flags_for_srcs)
"
${
multiValueArgs
}
"
${
ARGN
}
)
"
${
multiValueArgs
}
"
${
ARGN
}
)
foreach
(
_ARCH
${
arg_CUDA_ARCHS
}
)
foreach
(
_ARCH
${
arg_CUDA_ARCHS
}
)
string
(
REPLACE
"."
""
_ARCH
"
${
_ARCH
}
"
)
# handle +PTX suffix: generate both sm and ptx codes if requested
set_gencode_flag_for_srcs
(
string
(
FIND
"
${
_ARCH
}
"
"+PTX"
_HAS_PTX
)
SRCS
${
arg_SRCS
}
if
(
NOT _HAS_PTX EQUAL -1
)
ARCH
"compute_
${
_ARCH
}
"
string
(
REPLACE
"+PTX"
""
_BASE_ARCH
"
${
_ARCH
}
"
)
CODE
"sm_
${
_ARCH
}
"
)
string
(
REPLACE
"."
""
_STRIPPED_ARCH
"
${
_BASE_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_STRIPPED_ARCH
}
"
CODE
"sm_
${
_STRIPPED_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_STRIPPED_ARCH
}
"
CODE
"compute_
${
_STRIPPED_ARCH
}
"
)
else
()
string
(
REPLACE
"."
""
_STRIPPED_ARCH
"
${
_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_STRIPPED_ARCH
}
"
CODE
"sm_
${
_STRIPPED_ARCH
}
"
)
endif
()
endforeach
()
endforeach
()
if
(
${
arg_BUILD_PTX_FOR_ARCH
}
)
if
(
${
arg_BUILD_PTX_FOR_ARCH
}
)
...
@@ -256,7 +271,10 @@ endmacro()
...
@@ -256,7 +271,10 @@ endmacro()
#
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes.
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
# architecture in `SRC_CUDA_ARCHS`.
# The loose intersection is defined as:
# The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator.
# where `<=` is the version comparison operator.
...
@@ -273,44 +291,63 @@ endmacro()
...
@@ -273,44 +291,63 @@ endmacro()
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
#
# Example With PTX:
# SRC_CUDA_ARCHS="8.0+PTX"
# TGT_CUDA_ARCHS="9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0+PTX"
#
function
(
cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS
)
function
(
cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS
)
list
(
REMOVE_DUPLICATES SRC_CUDA_ARCHS
)
set
(
_SRC_CUDA_ARCHS
"
${
SRC_CUDA_ARCHS
}
"
)
set
(
TGT_CUDA_ARCHS_
${
TGT_CUDA_ARCHS
}
)
set
(
_TGT_CUDA_ARCHS
${
TGT_CUDA_ARCHS
}
)
# handle +PTX suffix: separate base arch for matching, record PTX requests
set
(
_PTX_ARCHS
)
foreach
(
_arch
${
_SRC_CUDA_ARCHS
}
)
if
(
_arch MATCHES
"
\\
+PTX$"
)
string
(
REPLACE
"+PTX"
""
_base
"
${
_arch
}
"
)
list
(
APPEND _PTX_ARCHS
"
${
_base
}
"
)
list
(
REMOVE_ITEM _SRC_CUDA_ARCHS
"
${
_arch
}
"
)
list
(
APPEND _SRC_CUDA_ARCHS
"
${
_base
}
"
)
endif
()
endforeach
()
list
(
REMOVE_DUPLICATES _PTX_ARCHS
)
list
(
REMOVE_DUPLICATES _SRC_CUDA_ARCHS
)
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set
(
_CUDA_ARCHS
)
set
(
_CUDA_ARCHS
)
if
(
"9.0a"
IN_LIST SRC_CUDA_ARCHS
)
if
(
"9.0a"
IN_LIST
_
SRC_CUDA_ARCHS
)
list
(
REMOVE_ITEM SRC_CUDA_ARCHS
"9.0a"
)
list
(
REMOVE_ITEM
_
SRC_CUDA_ARCHS
"9.0a"
)
if
(
"9.0"
IN_LIST TGT_CUDA_ARCHS
_
)
if
(
"9.0"
IN_LIST TGT_CUDA_ARCHS
)
list
(
REMOVE_ITEM TGT_CUDA_ARCHS
_
"9.0"
)
list
(
REMOVE_ITEM
_
TGT_CUDA_ARCHS
"9.0"
)
set
(
_CUDA_ARCHS
"9.0a"
)
set
(
_CUDA_ARCHS
"9.0a"
)
endif
()
endif
()
endif
()
endif
()
if
(
"10.0a"
IN_LIST SRC_CUDA_ARCHS
)
if
(
"10.0a"
IN_LIST
_
SRC_CUDA_ARCHS
)
list
(
REMOVE_ITEM SRC_CUDA_ARCHS
"10.0a"
)
list
(
REMOVE_ITEM
_
SRC_CUDA_ARCHS
"10.0a"
)
if
(
"10.0"
IN_LIST TGT_CUDA_ARCHS
)
if
(
"10.0"
IN_LIST TGT_CUDA_ARCHS
)
list
(
REMOVE_ITEM TGT_CUDA_ARCHS
_
"10.0"
)
list
(
REMOVE_ITEM
_
TGT_CUDA_ARCHS
"10.0"
)
set
(
_CUDA_ARCHS
"10.0a"
)
set
(
_CUDA_ARCHS
"10.0a"
)
endif
()
endif
()
endif
()
endif
()
list
(
SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING
)
list
(
SORT
_
SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING
)
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# is less or equal to ARCH (but has the same major version since SASS binary
# is less or equal to ARCH (but has the same major version since SASS binary
# compatibility is only forward compatible within the same major version).
# compatibility is only forward compatible within the same major version).
foreach
(
_ARCH
${
TGT_CUDA_ARCHS
_
}
)
foreach
(
_ARCH
${
_
TGT_CUDA_ARCHS
}
)
set
(
_TMP_ARCH
)
set
(
_TMP_ARCH
)
# Extract the major version of the target arch
# Extract the major version of the target arch
string
(
REGEX REPLACE
"^([0-9]+)
\\
..*$"
"
\\
1"
TGT_ARCH_MAJOR
"
${
_ARCH
}
"
)
string
(
REGEX REPLACE
"^([0-9]+)
\\
..*$"
"
\\
1"
TGT_ARCH_MAJOR
"
${
_ARCH
}
"
)
foreach
(
_SRC_ARCH
${
SRC_CUDA_ARCHS
}
)
foreach
(
_SRC_ARCH
${
_
SRC_CUDA_ARCHS
}
)
# Extract the major version of the source arch
# Extract the major version of the source arch
string
(
REGEX REPLACE
"^([0-9]+)
\\
..*$"
"
\\
1"
SRC_ARCH_MAJOR
"
${
_SRC_ARCH
}
"
)
string
(
REGEX REPLACE
"^([0-9]+)
\\
..*$"
"
\\
1"
SRC_ARCH_MAJOR
"
${
_SRC_ARCH
}
"
)
# Check
major-version match AND version-less-or-equal
# Check
version-less-or-equal, and allow PTX arches to match across majors
if
(
_SRC_ARCH VERSION_LESS_EQUAL _ARCH
)
if
(
_SRC_ARCH VERSION_LESS_EQUAL _ARCH
)
if
(
SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR
)
if
(
_SRC_ARCH IN_LIST _PTX_ARCHS OR
SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR
)
set
(
_TMP_ARCH
"
${
_SRC_ARCH
}
"
)
set
(
_TMP_ARCH
"
${
_SRC_ARCH
}
"
)
endif
()
endif
()
else
()
else
()
...
@@ -326,6 +363,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
...
@@ -326,6 +363,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endforeach
()
endforeach
()
list
(
REMOVE_DUPLICATES _CUDA_ARCHS
)
list
(
REMOVE_DUPLICATES _CUDA_ARCHS
)
# reapply +PTX suffix to architectures that requested PTX
set
(
_FINAL_ARCHS
)
foreach
(
_arch
${
_CUDA_ARCHS
}
)
if
(
_arch IN_LIST _PTX_ARCHS
)
list
(
APPEND _FINAL_ARCHS
"
${
_arch
}
+PTX"
)
else
()
list
(
APPEND _FINAL_ARCHS
"
${
_arch
}
"
)
endif
()
endforeach
()
set
(
_CUDA_ARCHS
${
_FINAL_ARCHS
}
)
set
(
${
OUT_CUDA_ARCHS
}
${
_CUDA_ARCHS
}
PARENT_SCOPE
)
set
(
${
OUT_CUDA_ARCHS
}
${
_CUDA_ARCHS
}
PARENT_SCOPE
)
endfunction
()
endfunction
()
...
...
Prev
1
2
3
4
5
6
7
8
9
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment