Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0c738b58
Unverified
Commit
0c738b58
authored
Dec 17, 2025
by
Bowen Bao
Committed by
GitHub
Dec 18, 2025
Browse files
[Quantization] Support Quark int4-fp8 w4a8 for MoE (#30071)
Signed-off-by:
Bowen Bao
<
bowenbao@amd.com
>
parent
5a3adf58
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
201 additions
and
2 deletions
+201
-2
vllm/model_executor/layers/quantization/quark/quark.py
vllm/model_executor/layers/quantization/quark/quark.py
+43
-0
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+158
-2
No files found.
vllm/model_executor/layers/quantization/quark/quark.py
View file @
0c738b58
...
...
@@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig):
else
:
return
False
def
_is_fp8_w4a8
(
self
,
weight_quant
:
list
[
dict
[
str
,
Any
]]
|
None
,
input_quant
:
dict
[
str
,
Any
]
|
None
,
)
->
bool
:
# Confirm weights and input quantized.
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
if
not
isinstance
(
weight_quant
,
list
)
or
len
(
weight_quant
)
!=
2
:
return
False
# Confirm weight scheme is supported
is_w4a8_dtype
=
(
weight_quant
[
0
].
get
(
"dtype"
)
==
"fp8_e4m3"
and
weight_quant
[
1
].
get
(
"dtype"
)
==
"int4"
and
input_quant
.
get
(
"dtype"
)
==
"fp8_e4m3"
)
is_static_weight
=
not
weight_quant
[
0
].
get
(
"is_dynamic"
)
and
not
weight_quant
[
1
].
get
(
"is_dynamic"
)
is_per_tensor_fp8_and_per_channel_int4_weight
=
(
weight_quant
[
0
].
get
(
"qscheme"
)
==
"per_tensor"
and
weight_quant
[
1
].
get
(
"qscheme"
)
==
"per_channel"
and
weight_quant
[
1
].
get
(
"symmetric"
)
is
True
and
weight_quant
[
1
].
get
(
"ch_axis"
)
==
0
)
if
not
(
is_w4a8_dtype
and
is_static_weight
and
is_per_tensor_fp8_and_per_channel_int4_weight
):
return
False
# Dynamic quantization is always supported if weights supported.
if
input_quant
.
get
(
"is_dynamic"
):
return
True
# Confirm activation scheme is supported.
is_per_tensor_activation
=
input_quant
.
get
(
"qscheme"
)
==
"per_tensor"
return
is_per_tensor_activation
def
_is_fp8_w8a8
(
self
,
weight_quant
:
dict
[
str
,
Any
]
|
None
,
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
0c738b58
...
...
@@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase):
)
weight_config
=
layer_quant_config
.
get
(
"weight"
)
input_config
=
layer_quant_config
.
get
(
"input_tensors"
)
if
quant_config
.
_is_fp8_w8a8
(
weight_config
,
input_config
):
if
quant_config
.
_is_fp8_w4a8
(
weight_config
,
input_config
):
return
QuarkW4A8Fp8MoEMethod
(
weight_config
,
input_config
,
module
.
moe_config
)
elif
quant_config
.
_is_fp8_w8a8
(
weight_config
,
input_config
):
return
QuarkW8A8Fp8MoEMethod
(
weight_config
,
input_config
,
module
.
moe_config
)
elif
quant_config
.
_is_ocp_mx
(
weight_config
,
input_config
):
return
QuarkOCP_MX_MoEMethod
(
weight_config
,
input_config
,
module
.
moe_config
)
...
...
@@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
    """Quark MoE method for int4 weights with fp8 scales and fp8 activations.

    Expert weights are stored packed (eight 4-bit values per 32-bit word)
    together with two levels of scale: a per-tensor fp8 scale
    (``*_weight_scale``) and a per-channel int4 scale (``*_weight_scale_2``).
    The forward pass is dispatched to the ROCm AITER fused-experts kernel.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        """Store the quantization configs; requires AITER fused MoE."""
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        assert rocm_aiter_ops.is_fused_moe_enabled(), (
            "W4A8 FP8 MoE requires ROCm AITER fused MoE support."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed weights and both scale levels on ``layer``.

        The incoming ``params_dtype`` is not used: the packed weights are
        always allocated as ``torch.uint32`` and every scale as
        ``torch.float32``.
        """

        def frozen(tensor: torch.Tensor) -> torch.nn.Parameter:
            # None of these tensors are trained.
            return torch.nn.Parameter(tensor, requires_grad=False)

        packed_dtype = torch.uint32
        gate_up_rows = 2 * intermediate_size_per_partition

        # Packed int4 weights: the last dimension shrinks by 8 because of
        # the INT32 packing for W4.
        w13_weight = frozen(
            torch.empty(
                num_experts,
                gate_up_rows,
                hidden_size // 8,
                dtype=packed_dtype,
            )
        )
        w2_weight = frozen(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 8,
                dtype=packed_dtype,
            )
        )
        layer.register_parameter("w13_weight", w13_weight)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Per-tensor fp8 weight scales (two entries per expert for w13).
        w13_weight_scale = frozen(torch.ones(num_experts, 2, dtype=torch.float32))
        w2_weight_scale = frozen(torch.ones(num_experts, dtype=torch.float32))
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        per_tensor_attrs = {
            **extra_weight_attrs,
            "quant_method": FusedMoeWeightScaleSupported.TENSOR.value,
        }
        set_weight_attrs(w13_weight_scale, per_tensor_attrs)
        set_weight_attrs(w2_weight_scale, per_tensor_attrs)

        # Per-channel int4 weight scales.
        w13_weight_scale_2 = frozen(
            torch.ones(num_experts, gate_up_rows, dtype=torch.float32)
        )
        w2_weight_scale_2 = frozen(
            torch.ones(num_experts, hidden_size, dtype=torch.float32)
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
        per_channel_attrs = {
            **extra_weight_attrs,
            "quant_method": FusedMoeWeightScaleSupported.CHANNEL.value,
        }
        set_weight_attrs(w13_weight_scale_2, per_channel_attrs)
        set_weight_attrs(w2_weight_scale_2, per_channel_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Shuffle the weights for AITER and fold fp8 scales into int4 scales.

        After this call ``layer.w13_weight_scale`` holds one value per
        expert (the max of its two shard scales), and the ``*_weight_scale_2``
        tensors have been multiplied by the matching fp8 scales.
        """
        w13_shuffled, w2_shuffled = rocm_aiter_ops.shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data
        )
        layer.w13_weight = torch.nn.Parameter(w13_shuffled, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_shuffled, requires_grad=False)

        # INT4-FP8: the fp8 MoE kernel needs a single fp8 w13 scale per
        # expert. The fp8 weights are not directly available, so they are
        # not requantized. Instead the shard whose fp8 scale is not the
        # expert maximum has its int4 per-channel scales multiplied by
        # (shard fp8 scale / max fp8 scale).
        shard_len = layer.intermediate_size_per_partition
        expert_max_scales = layer.w13_weight_scale.max(dim=1).values
        assert torch.all(expert_max_scales != 0), "fp8 weight scale cannot be zero."

        num_local = layer.local_num_experts
        for expert_idx in range(num_local):
            expert_max = expert_max_scales[expert_idx]
            for shard_idx in range(2):
                shard_scale = layer.w13_weight_scale[expert_idx][shard_idx]
                if shard_scale == expert_max:
                    # Already expressed relative to the max; nothing to fold.
                    continue
                lo = shard_idx * shard_len
                hi = lo + shard_len
                layer.w13_weight_scale_2[expert_idx][lo:hi] *= shard_scale / expert_max

        layer.w13_weight_scale = torch.nn.Parameter(
            expert_max_scales, requires_grad=False
        )

        # Special hack for asm_moe, which takes (int4 scale * fp8 scale) as
        # the post-GEMM scaling. The original note says the optimal design
        # would apply the per-column int4 scale before the GEMM and the fp8
        # scale after it.
        for expert_idx in range(num_local):
            layer.w13_weight_scale_2[expert_idx] *= expert_max_scales[expert_idx]
            layer.w2_weight_scale_2[expert_idx] *= layer.w2_weight_scale[expert_idx]

    def get_fused_moe_quant_config(self, layer):
        """Build the fused-MoE quant config from the folded per-channel scales."""
        w13_scale = layer.w13_weight_scale_2
        w2_scale = layer.w2_weight_scale_2
        return fp8_w8a8_moe_quant_config(
            w1_scale=w13_scale,
            w2_scale=w2_scale,
            per_out_ch_quant=True,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route ``x`` through the experts with the AITER fused kernel."""
        routing = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
        topk_weights, topk_ids, _ = routing

        # Imported lazily, after expert selection, as in the original code.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_fused_experts,
        )

        return rocm_aiter_fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            quant_config=self.moe_quant_config,
            expert_map=layer.expert_map,
        )
class
QuarkOCP_MX_MoEMethod
(
QuarkMoEMethod
):
def
__init__
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment