Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4671ac6e
Unverified
Commit
4671ac6e
authored
Jun 23, 2025
by
22quinn
Committed by
GitHub
Jun 24, 2025
Browse files
[Bugfix][Benchmark] Fix Marlin benchmark (#19929)
parent
dd2ccf8d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
150 additions
and
79 deletions
+150
-79
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+150
-79
No files found.
benchmarks/kernels/benchmark_marlin.py
View file @
4671ac6e
...
@@ -22,8 +22,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -22,8 +22,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES
,
MARLIN_SUPPORTED_GROUP_SIZES
,
query_marlin_supported_quant_types
,
query_marlin_supported_quant_types
,
)
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
FP4_MARLIN_SUPPORTED_GROUP_SIZES
,
rand_marlin_weight_fp4_like
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
marlin_quant_fp8_torch
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
MarlinWorkspace
,
awq_marlin_quantize
,
marlin_quantize
,
marlin_quantize
,
)
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
...
@@ -35,7 +43,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -35,7 +43,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights
,
quantize_weights
,
sort_weights
,
sort_weights
,
)
)
from
vllm.scalar_type
import
ScalarType
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
...
@@ -57,36 +65,79 @@ def bench_run(
...
@@ -57,36 +65,79 @@ def bench_run(
size_n
:
int
,
size_n
:
int
,
):
):
label
=
"Quant Matmul"
label
=
"Quant Matmul"
sub_label
=
"{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})"
.
format
(
sub_label
=
"{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})"
.
format
(
model
,
act_order
,
is_k_full
,
str
(
quant_type
),
group_size
,
size_m
,
size_k
,
size_n
model
,
act_order
,
is_k_full
,
str
(
quant_type
),
group_size
,
size_m
,
size_k
,
size_n
)
)
print
(
f
"Testing:
{
sub_label
}
"
)
print
(
f
"Testing:
{
sub_label
}
"
)
a
=
torch
.
randn
(
size_m
,
size_k
).
to
(
torch
.
half
).
cuda
()
a
=
torch
.
randn
(
size_m
,
size_k
).
to
(
torch
.
half
).
cuda
()
b
=
torch
.
rand
(
size_k
,
size_n
).
to
(
torch
.
half
).
cuda
()
b
=
torch
.
rand
(
size_k
,
size_n
).
to
(
torch
.
half
).
cuda
()
has_zp
=
quant_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
if
act_order
and
(
group_size
==
-
1
or
group_size
==
size_k
or
has_zp
):
return
if
size_k
%
group_size
!=
0
:
return
a_tmp
=
torch
.
zeros
(
size_m
,
size_k
).
to
(
torch
.
half
).
cuda
()
marlin_24_supported
=
(
quant_type
in
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and
group_size
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
repack_supported
=
(
quant_type
in
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and
group_size
in
MARLIN_SUPPORTED_GROUP_SIZES
)
allspark_supported
=
(
quant_type
in
ALLSPARK_SUPPORTED_QUANT_TYPES
and
group_size
==
-
1
and
not
act_order
and
is_k_full
)
def
gen_marlin_params
():
# Marlin quant
# Marlin quant
(
marlin_g_idx
=
marlin_sort_indices
=
marlin_zp
=
marlin_s2
=
None
if
quant_type
==
scalar_types
.
float4_e2m1f
:
if
group_size
!=
16
or
act_order
:
return
marlin_w_ref
,
marlin_q_w
,
marlin_s
,
marlin_s2
=
rand_marlin_weight_fp4_like
(
b
.
T
,
group_size
)
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
if
group_size
not
in
[
-
1
,
128
]
or
act_order
:
return
marlin_w_ref
,
marlin_q_w
,
marlin_s
=
marlin_quant_fp8_torch
(
b
.
T
,
group_size
)
elif
group_size
==
16
:
return
elif
has_zp
:
marlin_w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
=
awq_marlin_quantize
(
b
,
quant_type
,
group_size
)
else
:
marlin_w_ref
,
marlin_q_w
,
marlin_s
,
marlin_g_idx
,
marlin_sort_indices
,
_
=
(
marlin_quantize
(
b
,
quant_type
,
group_size
,
act_order
)
)
return
(
marlin_w_ref
,
marlin_w_ref
,
marlin_q_w
,
marlin_q_w
,
marlin_s
,
marlin_s
,
marlin_s2
,
marlin_zp
,
marlin_g_idx
,
marlin_g_idx
,
marlin_sort_indices
,
marlin_sort_indices
,
marlin_rand_perm
,
)
)
=
marlin_quantize
(
b
,
quant_type
,
group_size
,
act_order
)
# Marlin_24 quant
def
gen_marlin_24_params
():
marlin_24_w_ref
=
marlin_24_q_w_comp
=
marlin_24_meta
=
marlin_24_s
=
None
if
marlin_24_supported
:
(
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
=
(
(
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
=
(
marlin_24_quantize
(
b
,
quant_type
,
group_size
)
marlin_24_quantize
(
b
,
quant_type
,
group_size
)
)
)
return
(
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
marlin_zp
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
b
.
device
)
def
gen_repack_params
():
q_w_gptq
=
None
# GPTQ quant
repack_sort_indices
=
None
if
repack_supported
:
(
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
)
=
gptq_quantize_weights
(
(
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
)
=
gptq_quantize_weights
(
b
,
quant_type
,
group_size
,
act_order
b
,
quant_type
,
group_size
,
act_order
)
)
...
@@ -97,33 +148,21 @@ def bench_run(
...
@@ -97,33 +148,21 @@ def bench_run(
repack_sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
b
.
device
)
repack_sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
b
.
device
)
if
act_order
:
if
act_order
:
(
q_w
,
g_idx
,
repack_sort_indices
)
=
sort_weights
(
q_w
,
g_idx
)
(
q_w
,
g_idx
,
repack_sort_indices
)
=
sort_weights
(
q_w
,
g_idx
)
return
q_w_gptq
,
repack_sort_indices
# Prepare
marlin_workspace
=
MarlinWorkspace
(
def
gen_allspark_params
():
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
qw_reorder
=
s_reorder
=
zp_reorder
=
sm_count
=
sm_version
=
(
)
CUBLAS_M_THRESHOLD
)
=
None
marlin_24_workspace
=
MarlinWorkspace
(
nonlocal
allspark_supported
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
if
allspark_supported
:
)
marlin_zp
=
torch
.
zeros_like
(
marlin_s
,
dtype
=
torch
.
int
)
# AllSpark W8A16 quant
as_supported_case
=
(
quant_type
in
ALLSPARK_SUPPORTED_QUANT_TYPES
and
group_size
==
-
1
and
not
act_order
and
is_k_full
)
if
as_supported_case
:
properties
=
torch
.
cuda
.
get_device_properties
(
b
.
device
.
index
)
properties
=
torch
.
cuda
.
get_device_properties
(
b
.
device
.
index
)
sm_count
=
properties
.
multi_processor_count
sm_count
=
properties
.
multi_processor_count
sm_version
=
properties
.
major
*
10
+
properties
.
minor
sm_version
=
properties
.
major
*
10
+
properties
.
minor
supported_arch
=
sm_version
>=
80
and
sm_version
<
90
supported_arch
=
sm_version
>=
80
and
sm_version
<
90
as_supported_case
=
as
_supported
_case
and
supported_arch
allspark_supported
=
allspark
_supported
and
supported_arch
if
supported_arch
:
if
supported_arch
:
has_zp
=
False
w_ref
,
qw
,
s
,
zp
=
quantize_weights
(
b
,
quant_type
,
group_size
,
has_zp
)
w_ref
,
qw
,
s
,
zp
=
quantize_weights
(
b
,
quant_type
,
group_size
,
has_zp
)
qw
=
qw
.
to
(
torch
.
uint8
)
qw
=
qw
.
to
(
torch
.
uint8
)
...
@@ -131,6 +170,39 @@ def bench_run(
...
@@ -131,6 +170,39 @@ def bench_run(
qw
,
s
,
zp
,
has_zp
qw
,
s
,
zp
,
has_zp
)
)
CUBLAS_M_THRESHOLD
=
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
CUBLAS_M_THRESHOLD
=
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
return
(
qw_reorder
,
s_reorder
,
zp_reorder
,
sm_count
,
sm_version
,
CUBLAS_M_THRESHOLD
,
)
(
marlin_w_ref
,
marlin_q_w
,
marlin_s
,
marlin_s2
,
marlin_zp
,
marlin_g_idx
,
marlin_sort_indices
,
)
=
gen_marlin_params
()
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
=
(
gen_marlin_24_params
()
)
q_w_gptq
,
repack_sort_indices
=
gen_repack_params
()
qw_reorder
,
s_reorder
,
zp_reorder
,
sm_count
,
sm_version
,
CUBLAS_M_THRESHOLD
=
(
gen_allspark_params
()
)
# Prepare
marlin_workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
marlin_24_workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
)
globals
=
{
globals
=
{
# Gen params
# Gen params
...
@@ -140,15 +212,14 @@ def bench_run(
...
@@ -140,15 +212,14 @@ def bench_run(
"size_n"
:
size_n
,
"size_n"
:
size_n
,
"size_k"
:
size_k
,
"size_k"
:
size_k
,
"a"
:
a
,
"a"
:
a
,
"a_tmp"
:
a_tmp
,
# Marlin params
# Marlin params
"marlin_w_ref"
:
marlin_w_ref
,
"marlin_w_ref"
:
marlin_w_ref
,
"marlin_q_w"
:
marlin_q_w
,
"marlin_q_w"
:
marlin_q_w
,
"marlin_s"
:
marlin_s
,
"marlin_s"
:
marlin_s
,
"marlin_s2"
:
marlin_s2
,
"marlin_zp"
:
marlin_zp
,
"marlin_zp"
:
marlin_zp
,
"marlin_g_idx"
:
marlin_g_idx
,
"marlin_g_idx"
:
marlin_g_idx
,
"marlin_sort_indices"
:
marlin_sort_indices
,
"marlin_sort_indices"
:
marlin_sort_indices
,
"marlin_rand_perm"
:
marlin_rand_perm
,
"marlin_workspace"
:
marlin_workspace
,
"marlin_workspace"
:
marlin_workspace
,
"is_k_full"
:
is_k_full
,
"is_k_full"
:
is_k_full
,
# Marlin_24 params
# Marlin_24 params
...
@@ -161,12 +232,12 @@ def bench_run(
...
@@ -161,12 +232,12 @@ def bench_run(
"q_w_gptq"
:
q_w_gptq
,
"q_w_gptq"
:
q_w_gptq
,
"repack_sort_indices"
:
repack_sort_indices
,
"repack_sort_indices"
:
repack_sort_indices
,
# AllSpark W8A16 params
# AllSpark W8A16 params
"qw_reorder"
:
qw_reorder
if
as_supported_case
else
None
,
"qw_reorder"
:
qw_reorder
,
"s_reorder"
:
s_reorder
if
as_supported_case
else
None
,
"s_reorder"
:
s_reorder
,
"zp_reorder"
:
zp_reorder
if
as_supported_case
else
None
,
"zp_reorder"
:
zp_reorder
,
"sm_count"
:
sm_count
if
as_supported_case
else
None
,
"sm_count"
:
sm_count
,
"sm_version"
:
sm_version
if
as_supported_case
else
None
,
"sm_version"
:
sm_version
,
"CUBLAS_M_THRESHOLD"
:
CUBLAS_M_THRESHOLD
if
as_supported_case
else
None
,
"CUBLAS_M_THRESHOLD"
:
CUBLAS_M_THRESHOLD
,
# Kernels
# Kernels
"gptq_marlin_gemm"
:
ops
.
gptq_marlin_gemm
,
"gptq_marlin_gemm"
:
ops
.
gptq_marlin_gemm
,
"gptq_marlin_24_gemm"
:
ops
.
gptq_marlin_24_gemm
,
"gptq_marlin_24_gemm"
:
ops
.
gptq_marlin_24_gemm
,
...
@@ -177,7 +248,7 @@ def bench_run(
...
@@ -177,7 +248,7 @@ def bench_run(
min_run_time
=
1
min_run_time
=
1
# Warmup pytorch
# Warmup pytorch
for
i
in
range
(
5
):
for
_
in
range
(
5
):
torch
.
matmul
(
a
,
marlin_w_ref
)
torch
.
matmul
(
a
,
marlin_w_ref
)
results
.
append
(
results
.
append
(
...
@@ -192,17 +263,17 @@ def bench_run(
...
@@ -192,17 +263,17 @@ def bench_run(
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)"
,
# noqa: E501
stmt
=
"output = gptq_marlin_gemm(a,
None,
marlin_q_w, marlin_s,
marlin_s2,
marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
description
=
"gptq_marlin_gemm
_fp16
"
,
description
=
"gptq_marlin_gemm"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
)
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)"
,
# noqa: E501
stmt
=
"output = gptq_marlin_gemm(a,
None,
marlin_q_w, marlin_s,
marlin_s2,
marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
...
@@ -210,10 +281,7 @@ def bench_run(
...
@@ -210,10 +281,7 @@ def bench_run(
).
blocked_autorange
(
min_run_time
=
min_run_time
)
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
)
if
(
if
marlin_24_supported
:
quant_type
in
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and
group_size
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
):
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)"
,
# noqa: E501
stmt
=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)"
,
# noqa: E501
...
@@ -224,6 +292,7 @@ def bench_run(
...
@@ -224,6 +292,7 @@ def bench_run(
).
blocked_autorange
(
min_run_time
=
min_run_time
)
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
)
if
repack_supported
:
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)"
,
# noqa: E501
stmt
=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)"
,
# noqa: E501
...
@@ -234,7 +303,7 @@ def bench_run(
...
@@ -234,7 +303,7 @@ def bench_run(
).
blocked_autorange
(
min_run_time
=
min_run_time
)
).
blocked_autorange
(
min_run_time
=
min_run_time
)
)
)
if
a
s
_supported
_case
:
if
a
llspark
_supported
:
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)"
,
# noqa: E501
stmt
=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)"
,
# noqa: E501
...
@@ -250,7 +319,6 @@ def main(args):
...
@@ -250,7 +319,6 @@ def main(args):
print
(
"Benchmarking models:"
)
print
(
"Benchmarking models:"
)
for
i
,
model
in
enumerate
(
args
.
models
):
for
i
,
model
in
enumerate
(
args
.
models
):
print
(
f
"[
{
i
}
]
{
model
}
"
)
print
(
f
"[
{
i
}
]
{
model
}
"
)
results
:
list
[
benchmark
.
Measurement
]
=
[]
results
:
list
[
benchmark
.
Measurement
]
=
[]
for
model
in
args
.
models
:
for
model
in
args
.
models
:
...
@@ -278,14 +346,17 @@ def main(args):
...
@@ -278,14 +346,17 @@ def main(args):
):
):
continue
continue
for
quant_type
in
query_marlin_supported_quant_types
(
False
):
for
quant_type
in
query_marlin_supported_quant_types
():
if
(
if
(
len
(
args
.
limit_num_bits
)
>
0
len
(
args
.
limit_num_bits
)
>
0
and
quant_type
.
size_bits
not
in
args
.
limit_num_bits
and
quant_type
.
size_bits
not
in
args
.
limit_num_bits
):
):
continue
continue
for
group_size
in
MARLIN_SUPPORTED_GROUP_SIZES
:
for
group_size
in
(
MARLIN_SUPPORTED_GROUP_SIZES
+
FP4_MARLIN_SUPPORTED_GROUP_SIZES
):
if
(
if
(
len
(
args
.
limit_group_size
)
>
0
len
(
args
.
limit_group_size
)
>
0
and
group_size
not
in
args
.
limit_group_size
and
group_size
not
in
args
.
limit_group_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment