Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
827268e9
Unverified
Commit
827268e9
authored
Apr 10, 2026
by
PikaPikachu
Committed by
GitHub
Apr 09, 2026
Browse files
[Quantization] Support Quark W8A8 INT8 MoE inference (#36320)
Signed-off-by:
kangletian
<
Letian.Kang@amd.com
>
parent
56e19d7e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
360 additions
and
2 deletions
+360
-2
tests/quantization/test_quark.py
tests/quantization/test_quark.py
+31
-0
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+9
-2
vllm/model_executor/layers/quantization/quark/quark.py
vllm/model_executor/layers/quantization/quark/quark.py
+38
-0
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+282
-0
No files found.
tests/quantization/test_quark.py
View file @
827268e9
...
@@ -22,6 +22,9 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
...
@@ -22,6 +22,9 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkW8A8Fp8
,
QuarkW8A8Fp8
,
QuarkW8A8Int8
,
QuarkW8A8Int8
,
)
)
from
vllm.model_executor.layers.quantization.quark.quark_moe
import
(
# noqa: E501
QuarkW8A8Int8MoEMethod
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.reference_mxfp4
import
dq_mxfp4_torch
,
qdq_mxfp4_torch
from
.reference_mxfp4
import
dq_mxfp4_torch
,
qdq_mxfp4_torch
...
@@ -126,6 +129,34 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
...
@@ -126,6 +129,34 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
assert
output
assert
output
@
pytest
.
mark
.
parametrize
(
"tp"
,
[
1
])
def
test_quark_int8_w8a8_moe
(
vllm_runner
,
tp
):
"""Test W8A8 INT8 MoE quantization with a tiny Qwen3 MoE model."""
model_path
=
"nameistoken/tiny-qwen3-moe-w8a8-int8-quark"
with
vllm_runner
(
model_path
,
enforce_eager
=
True
,
tensor_parallel_size
=
tp
,
gpu_memory_utilization
=
0.1
,
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
# MoE experts should use QuarkW8A8Int8MoEMethod
moe
=
layer
.
mlp
.
experts
assert
isinstance
(
moe
.
quant_method
,
QuarkW8A8Int8MoEMethod
),
(
f
"Expected QuarkW8A8Int8MoEMethod, got
{
type
(
moe
.
quant_method
)
}
"
)
# Non-MoE linear layers should use QuarkW8A8Int8
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
scheme
,
QuarkW8A8Int8
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello"
,
max_tokens
=
4
)
assert
output
def
test_quark_fp8_parity
(
vllm_runner
):
def
test_quark_fp8_parity
(
vllm_runner
):
quark_model_id
=
"amd-quark/llama-tiny-fp8-quark-quant-method"
quark_model_id
=
"amd-quark/llama-tiny-fp8-quark-quant-method"
fp8_model_id
=
"amd-quark/llama-tiny-fp8-quant-method"
fp8_model_id
=
"amd-quark/llama-tiny-fp8-quant-method"
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
827268e9
...
@@ -163,8 +163,15 @@ def _int8_quantize(
...
@@ -163,8 +163,15 @@ def _int8_quantize(
# activations apply per-token quantization. Otherwise, assume
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
# activation tensor-wise fp8/int8 quantization, dynamic or static
if
block_shape
is
None
:
if
block_shape
is
None
:
assert
per_act_token
,
"int8 quantization only supports block or channel-wise"
if
per_act_token
:
A
,
A_scale
=
per_token_quant_int8
(
A
)
A
,
A_scale
=
per_token_quant_int8
(
A
)
elif
A_scale
is
not
None
:
# Static per-tensor: use the optimized CUDA kernel
A
,
A_scale
,
_
=
ops
.
scaled_int8_quant
(
A
,
scale
=
A_scale
)
elif
A_scale
is
None
:
# Dynamic per-tensor: compute scale then quantize via kernel
A_scale
=
torch
.
clamp
(
A
.
abs
().
max
()
/
127.0
,
min
=
1e-10
)
A
,
A_scale
,
_
=
ops
.
scaled_int8_quant
(
A
,
scale
=
A_scale
)
else
:
else
:
assert
not
per_act_token
assert
not
per_act_token
assert
len
(
block_shape
)
==
2
assert
len
(
block_shape
)
==
2
...
...
vllm/model_executor/layers/quantization/quark/quark.py
View file @
827268e9
...
@@ -389,6 +389,37 @@ class QuarkConfig(QuantizationConfig):
...
@@ -389,6 +389,37 @@ class QuarkConfig(QuantizationConfig):
return
is_weight_mxfp4
and
is_input_fp8
return
is_weight_mxfp4
and
is_input_fp8
def
_is_dynamic_per_token_w8a8
(
self
,
weight_quant
:
dict
[
str
,
Any
]
|
None
,
input_quant
:
dict
[
str
,
Any
]
|
None
,
)
->
bool
:
"""Detect W8A8 INT8 with per-tensor or per-channel
weights and dynamic per-token input."""
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
is_int8_dtype
=
(
weight_quant
.
get
(
"dtype"
)
==
"int8"
and
input_quant
.
get
(
"dtype"
)
==
"int8"
)
is_valid_weight_scheme
=
weight_quant
.
get
(
"qscheme"
)
in
[
"per_tensor"
,
"per_channel"
,
]
is_per_token_input
=
input_quant
.
get
(
"qscheme"
)
==
"per_channel"
is_dynamic_input
=
input_quant
.
get
(
"is_dynamic"
)
is
True
is_weight_symmetric
=
weight_quant
.
get
(
"symmetric"
)
is
True
return
(
is_int8_dtype
and
is_valid_weight_scheme
and
is_per_token_input
and
is_dynamic_input
and
is_weight_symmetric
)
def
_is_w_ocp_mx_a_x
(
def
_is_w_ocp_mx_a_x
(
self
,
weight_quant
:
dict
[
str
,
Any
]
|
None
,
input_quant
:
dict
[
str
,
Any
]
|
None
self
,
weight_quant
:
dict
[
str
,
Any
]
|
None
,
input_quant
:
dict
[
str
,
Any
]
|
None
)
->
bool
:
)
->
bool
:
...
@@ -556,6 +587,13 @@ class QuarkConfig(QuantizationConfig):
...
@@ -556,6 +587,13 @@ class QuarkConfig(QuantizationConfig):
)
)
if
is_w4a8_supported
:
if
is_w4a8_supported
:
return
QuarkW4A8_MXFP4_FP8
(
weight_config
,
input_config
)
return
QuarkW4A8_MXFP4_FP8
(
weight_config
,
input_config
)
elif
self
.
_is_dynamic_per_token_w8a8
(
weight_config
,
input_config
):
weight_qscheme
=
cast
(
str
,
weight_config
.
get
(
"qscheme"
))
return
QuarkW8A8Int8
(
qscheme
=
weight_qscheme
,
is_static_input_scheme
=
False
,
input_symmetric
=
input_config
.
get
(
"symmetric"
),
)
elif
self
.
_is_w_ocp_mx_a_x
(
weight_config
,
input_config
):
elif
self
.
_is_w_ocp_mx_a_x
(
weight_config
,
input_config
):
return
QuarkOCP_MX
(
return
QuarkOCP_MX
(
weight_config
,
input_config
,
dynamic_mxfp4_quant
=
dynamic_mxfp4_quant
weight_config
,
input_config
,
dynamic_mxfp4_quant
=
dynamic_mxfp4_quant
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
827268e9
...
@@ -109,6 +109,12 @@ class QuarkMoEMethod(FusedMoEMethodBase):
...
@@ -109,6 +109,12 @@ class QuarkMoEMethod(FusedMoEMethodBase):
return
QuarkOCP_MX_MoEMethod
(
return
QuarkOCP_MX_MoEMethod
(
weight_config
,
input_config
,
module
.
moe_config
weight_config
,
input_config
,
module
.
moe_config
)
)
elif
quant_config
.
_is_static_tensor_w8a8
(
weight_config
,
input_config
)
or
quant_config
.
_is_dynamic_per_token_w8a8
(
weight_config
,
input_config
):
return
QuarkW8A8Int8MoEMethod
(
weight_config
,
input_config
,
module
.
moe_config
)
else
:
else
:
raise
RuntimeError
(
"Unsupported FusedMoe scheme"
)
raise
RuntimeError
(
"Unsupported FusedMoe scheme"
)
...
@@ -505,6 +511,282 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -505,6 +511,282 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
)
class
QuarkW8A8Int8MoEMethod
(
QuarkMoEMethod
):
"""Quark W8A8 INT8 MoE method."""
def
__init__
(
self
,
weight_config
:
dict
[
str
,
Any
],
input_config
:
dict
[
str
,
Any
],
moe
:
FusedMoEConfig
,
):
super
().
__init__
(
moe
)
self
.
weight_quant
=
weight_config
self
.
input_quant
=
input_config
self
.
weight_qscheme
=
self
.
weight_quant
.
get
(
"qscheme"
,
"per_tensor"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
get
(
"is_dynamic"
,
False
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
params_dtype
=
torch
.
int8
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
if
self
.
weight_qscheme
==
"per_channel"
:
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
else
:
# per-tensor: one scalar per expert
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
static_input_scales
:
w13_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
extra_weight_attrs
)
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
else
:
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
# ZERO POINTS (loaded but discarded after loading; kernel uses symmetric)
w13_input_zero_point
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_input_zero_point"
,
w13_input_zero_point
)
set_weight_attrs
(
w13_input_zero_point
,
extra_weight_attrs
)
w2_input_zero_point
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_input_zero_point"
,
w2_input_zero_point
)
set_weight_attrs
(
w2_input_zero_point
,
extra_weight_attrs
)
if
self
.
weight_qscheme
==
"per_channel"
:
w13_weight_zero_point
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
,
dtype
=
torch
.
int8
,
),
requires_grad
=
False
,
)
w2_weight_zero_point
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
else
:
w13_weight_zero_point
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
w2_weight_zero_point
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_zero_point"
,
w13_weight_zero_point
)
set_weight_attrs
(
w13_weight_zero_point
,
extra_weight_attrs
)
layer
.
register_parameter
(
"w2_weight_zero_point"
,
w2_weight_zero_point
)
set_weight_attrs
(
w2_weight_zero_point
,
extra_weight_attrs
)
# BIAS
if
self
.
has_bias
:
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
set_weight_attrs
(
w13_bias
,
extra_weight_attrs
)
w2_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
set_weight_attrs
(
w2_bias
,
extra_weight_attrs
)
else
:
layer
.
w13_bias
,
layer
.
w2_bias
=
None
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Discard zero points (INT8 fused MoE kernel uses symmetric quant)
for
attr
in
(
"w13_input_zero_point"
,
"w2_input_zero_point"
,
"w13_weight_zero_point"
,
"w2_weight_zero_point"
,
):
if
hasattr
(
layer
,
attr
):
delattr
(
layer
,
attr
)
# For static input scales, collapse per-expert scales to single max
if
self
.
static_input_scales
:
if
layer
.
w13_input_scale
is
None
or
layer
.
w2_input_scale
is
None
:
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
not
all_close_1d
(
layer
.
w13_input_scale
)
or
not
all_close_1d
(
layer
.
w2_input_scale
):
logger
.
warning_once
(
"Found input_scales that are not equal for "
"INT8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# For per-tensor weights, merge w1/w3 scales into single per-expert
if
self
.
weight_qscheme
==
"per_tensor"
:
assert
layer
.
w13_weight_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
local_num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_weight_scale
[
expert_id
][
shard_id
],
)
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
,
_
=
(
ops
.
scaled_int8_quant
(
dq_weight
,
scale
=
max_w13_scales
[
expert_id
],
)
)
start
+=
shard_size
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
is_dynamic
=
not
self
.
static_input_scales
is_per_channel
=
self
.
weight_qscheme
==
"per_channel"
return
FusedMoEQuantConfig
.
make
(
torch
.
int8
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
w1_bias
=
getattr
(
layer
,
"w13_bias"
,
None
),
w2_bias
=
getattr
(
layer
,
"w2_bias"
,
None
),
per_act_token_quant
=
is_dynamic
,
per_out_ch_quant
=
is_per_channel
,
block_shape
=
None
,
)
def
apply
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
not
self
.
moe
.
disable_inplace
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
moe_quant_config
,
)
class
QuarkW4A8Fp8MoEMethod
(
QuarkMoEMethod
):
class
QuarkW4A8Fp8MoEMethod
(
QuarkMoEMethod
):
def
__init__
(
def
__init__
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment