Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c73ceb5
Unverified
Commit
7c73ceb5
authored
Dec 21, 2025
by
Jinzhen Lin
Committed by
GitHub
Dec 20, 2025
Browse files
[Quantization] add marlin w4a8/w8a8 check (#31061)
Signed-off-by:
Jinzhen Lin
<
jinzhen.ljz@antgroup.com
>
parent
ae0770fa
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
0 deletions
+28
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+12
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
...el_executor/layers/quantization/utils/marlin_utils_fp4.py
+12
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+4
-0
No files found.
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
7c73ceb5
...
@@ -594,9 +594,15 @@ def apply_awq_marlin_linear(
...
@@ -594,9 +594,15 @@ def apply_awq_marlin_linear(
a_scales
=
None
a_scales
=
None
if
input_dtype
==
torch
.
int8
:
if
input_dtype
==
torch
.
int8
:
assert
quant_type
==
scalar_types
.
uint4
,
(
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x
,
a_scales
=
marlin_quant_input
(
reshaped_x
,
input_dtype
)
reshaped_x
,
a_scales
=
marlin_quant_input
(
reshaped_x
,
input_dtype
)
a_scales
=
a_scales
*
input_global_scale
a_scales
=
a_scales
*
input_global_scale
elif
input_dtype
==
torch
.
float8_e4m3fn
:
elif
input_dtype
==
torch
.
float8_e4m3fn
:
assert
quant_type
==
scalar_types
.
uint4
,
(
"INT8 weight + FP8 activation is not supported."
)
reshaped_x
,
a_scales
=
marlin_quant_input
(
reshaped_x
,
input_dtype
)
reshaped_x
,
a_scales
=
marlin_quant_input
(
reshaped_x
,
input_dtype
)
output
=
ops
.
gptq_marlin_gemm
(
output
=
ops
.
gptq_marlin_gemm
(
...
@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(
...
@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(
a_scales
=
None
a_scales
=
None
if
input_dtype
==
torch
.
int8
:
if
input_dtype
==
torch
.
int8
:
assert
quant_type
==
scalar_types
.
uint4b8
,
(
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x
,
a_scales
=
marlin_quant_input
(
reshaped_x
,
input_dtype
)
reshaped_x
,
a_scales
=
marlin_quant_input
(
reshaped_x
,
input_dtype
)
a_scales
=
a_scales
*
input_global_scale
a_scales
=
a_scales
*
input_global_scale
elif
input_dtype
==
torch
.
float8_e4m3fn
:
elif
input_dtype
==
torch
.
float8_e4m3fn
:
assert
quant_type
==
scalar_types
.
uint4b8
,
(
"INT8 weight + FP8 activation is not supported."
)
reshaped_x
,
a_scales
=
marlin_quant_input
(
reshaped_x
,
input_dtype
)
reshaped_x
,
a_scales
=
marlin_quant_input
(
reshaped_x
,
input_dtype
)
output
=
ops
.
gptq_marlin_gemm
(
output
=
ops
.
gptq_marlin_gemm
(
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
View file @
7c73ceb5
...
@@ -154,6 +154,12 @@ def prepare_fp4_layer_for_marlin(
...
@@ -154,6 +154,12 @@ def prepare_fp4_layer_for_marlin(
)
)
is_nvfp4
=
hasattr
(
layer
,
"weight_scale_2"
)
is_nvfp4
=
hasattr
(
layer
,
"weight_scale_2"
)
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
if
is_nvfp4
:
raise
RuntimeError
(
"NVFP4 weight + INT8/FP8 activation is not supported."
)
elif
input_dtype
!=
torch
.
float8_e4m3fn
:
raise
RuntimeError
(
"MXFP4 weight + INT8 activation is not supported."
)
group_size
=
16
if
is_nvfp4
else
32
group_size
=
16
if
is_nvfp4
else
32
part_size_n
=
layer
.
output_size_per_partition
part_size_n
=
layer
.
output_size_per_partition
...
@@ -231,6 +237,12 @@ def prepare_moe_fp4_layer_for_marlin(
...
@@ -231,6 +237,12 @@ def prepare_moe_fp4_layer_for_marlin(
)
)
is_nvfp4
=
hasattr
(
layer
,
"w13_weight_scale_2"
)
is_nvfp4
=
hasattr
(
layer
,
"w13_weight_scale_2"
)
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
if
is_nvfp4
:
raise
RuntimeError
(
"NVFP4 weight + INT8/FP8 activation is not supported."
)
elif
input_dtype
!=
torch
.
float8_e4m3fn
:
raise
RuntimeError
(
"MXFP4 weight + INT8 activation is not supported."
)
group_size
=
16
if
is_nvfp4
else
32
group_size
=
16
if
is_nvfp4
else
32
e
=
layer
.
num_experts
e
=
layer
.
num_experts
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
7c73ceb5
...
@@ -99,6 +99,8 @@ def prepare_fp8_layer_for_marlin(
...
@@ -99,6 +99,8 @@ def prepare_fp8_layer_for_marlin(
"be used leveraging the Marlin kernel. This may degrade "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
"performance for compute-heavy workloads."
)
)
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
raise
RuntimeError
(
"Marlin W8A8 is not supported."
)
part_size_n
=
layer
.
output_size_per_partition
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
...
@@ -206,6 +208,8 @@ def prepare_moe_fp8_layer_for_marlin(
...
@@ -206,6 +208,8 @@ def prepare_moe_fp8_layer_for_marlin(
"be used leveraging the Marlin kernel. This may degrade "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
"performance for compute-heavy workloads."
)
)
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
raise
RuntimeError
(
"Marlin W8A8 is not supported."
)
e
=
layer
.
num_experts
e
=
layer
.
num_experts
k
=
layer
.
hidden_size
k
=
layer
.
hidden_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment