Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6341d430
Unverified
Commit
6341d430
authored
Mar 13, 2026
by
Divakar Verma
Committed by
GitHub
Mar 13, 2026
Browse files
[ROCm][Quantization] add quark w4a8 mxfp4_fp8 for LinearLayer (#35316)
Signed-off-by:
Divakar Verma
<
divakar.verma@amd.com
>
parent
7afe0faa
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
311 additions
and
1 deletion
+311
-1
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+53
-0
vllm/model_executor/layers/quantization/quark/quark.py
vllm/model_executor/layers/quantization/quark/quark.py
+32
-0
vllm/model_executor/layers/quantization/quark/schemes/__init__.py
...el_executor/layers/quantization/quark/schemes/__init__.py
+8
-1
vllm/model_executor/layers/quantization/quark/schemes/quark_w4a8_mxfp4_fp8.py
...layers/quantization/quark/schemes/quark_w4a8_mxfp4_fp8.py
+218
-0
No files found.
vllm/_aiter_ops.py
View file @
6341d430
...
...
@@ -861,6 +861,39 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
return
out
,
residual_out
def _rocm_aiter_gemm_a8wfp4_impl(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Run the AITER Triton ``gemm_a8wfp4`` kernel into a new output tensor.

    Args:
        x: Activation operand; its first dimension gives the output rows.
        w: Weight operand; its first dimension gives the output columns.
        x_scales: Scales forwarded to the kernel for ``x``.
        w_scales: Scales forwarded to the kernel for ``w``.
        out_dtype: Dtype of the allocated output, also passed to the kernel.

    Returns:
        A tensor of shape ``(x.shape[0], w.shape[0])`` on ``x.device`` that
        the kernel was given as its ``y`` argument.
    """
    # Imported lazily so this module can be loaded without AITER installed.
    from aiter.ops.triton.gemm_a8wfp4 import gemm_a8wfp4

    out_shape = (x.shape[0], w.shape[0])
    out = torch.empty(out_shape, dtype=out_dtype, device=x.device)

    # config=None is passed through unchanged; the kernel receives `out` as
    # its `y` buffer.
    kernel_kwargs = dict(
        x=x,
        w=w,
        y=out,
        x_scales=x_scales,
        w_scales=w_scales,
        dtype=out_dtype,
        config=None,
    )
    gemm_a8wfp4(**kernel_kwargs)
    return out
def
_rocm_aiter_gemm_a8wfp4_fake
(
x
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
,
w_scales
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
return
torch
.
empty
(
x
.
shape
[
0
],
w
.
shape
[
0
],
dtype
=
out_dtype
,
device
=
x
.
device
)
def
_triton_rotary_embedding_impl
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
...
@@ -1337,6 +1370,14 @@ class rocm_aiter_ops:
dispatch_key
=
current_platform
.
dispatch_key
,
)
# Register the a8wfp4 GEMM custom op. The implementation allocates and
# returns its own output tensor, so no input argument is listed as mutated.
direct_register_custom_op(
    op_name="rocm_aiter_gemm_a8wfp4",
    op_func=_rocm_aiter_gemm_a8wfp4_impl,
    mutates_args=[],
    fake_impl=_rocm_aiter_gemm_a8wfp4_fake,
    dispatch_key=current_platform.dispatch_key,
)
# Register rocm aiter rotary embedding custom op
direct_register_custom_op
(
op_name
=
"rocm_aiter_triton_rotary_embedding"
,
...
...
@@ -1646,6 +1687,18 @@ class rocm_aiter_ops:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
vllm
.
rocm_aiter_per_token_quant
(
x
,
quant_dtype
,
scale
)
@staticmethod
def gemm_a8wfp4(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Dispatch to the registered ``rocm_aiter_gemm_a8wfp4`` custom op.

    All five arguments are forwarded positionally and unchanged; the op's
    result is returned as-is.
    """
    a8wfp4_op = torch.ops.vllm.rocm_aiter_gemm_a8wfp4
    return a8wfp4_op(x, w, x_scales, w_scales, out_dtype)
@
staticmethod
def
triton_fp4_gemm_dynamic_qaunt
(
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/quark/quark.py
View file @
6341d430
...
...
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E
from
vllm.model_executor.layers.quantization.quark.schemes
import
(
QuarkOCP_MX
,
QuarkScheme
,
QuarkW4A8_MXFP4_FP8
,
QuarkW8A8Fp8
,
QuarkW8A8Int8
,
)
...
...
@@ -350,6 +351,31 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return
is_int8_dtype
and
is_tensor
and
is_weight_symmetric
and
is_static
def
_is_w4a8_mxfp4_fp8
(
self
,
weight_quant
:
dict
[
str
,
Any
]
|
None
,
input_quant
:
dict
[
str
,
Any
]
|
None
,
)
->
bool
:
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
is_weight_mxfp4
=
(
weight_quant
.
get
(
"dtype"
)
==
"fp4"
and
weight_quant
.
get
(
"qscheme"
)
==
"per_group"
and
weight_quant
.
get
(
"group_size"
)
==
32
and
weight_quant
.
get
(
"scale_format"
)
==
"e8m0"
and
not
weight_quant
.
get
(
"is_dynamic"
)
)
is_input_fp8
=
(
input_quant
.
get
(
"dtype"
)
==
"fp8_e4m3"
and
input_quant
.
get
(
"qscheme"
)
==
"per_tensor"
and
not
input_quant
.
get
(
"is_dynamic"
)
# Static per-tensor
and
input_quant
.
get
(
"symmetric"
)
is
True
# Symmetric quantization
)
return
is_weight_mxfp4
and
is_input_fp8
def
_is_w_ocp_mx_a_x
(
self
,
weight_quant
:
dict
[
str
,
Any
]
|
None
,
input_quant
:
dict
[
str
,
Any
]
|
None
)
->
bool
:
...
...
@@ -504,6 +530,12 @@ class QuarkConfig(QuantizationConfig):
is_static_input_scheme
=
True
,
input_symmetric
=
input_config
.
get
(
"symmetric"
),
)
elif self._is_w4a8_mxfp4_fp8(weight_config, input_config):
    # error=False presumably makes the capability check return a bool
    # instead of raising — TODO confirm against _check_scheme_supported.
    is_w4a8_supported = self._check_scheme_supported(
        QuarkW4A8_MXFP4_FP8.get_min_capability(), error=False
    )
    # NOTE(review): if the check fails, this branch returns nothing and the
    # remaining elif branches are skipped; confirm the code after the chain
    # handles that case.
    if is_w4a8_supported:
        return QuarkW4A8_MXFP4_FP8(weight_config, input_config)
elif
self
.
_is_w_ocp_mx_a_x
(
weight_config
,
input_config
):
return
QuarkOCP_MX
(
weight_config
,
input_config
,
dynamic_mxfp4_quant
=
dynamic_mxfp4_quant
...
...
vllm/model_executor/layers/quantization/quark/schemes/__init__.py
View file @
6341d430
...
...
@@ -3,7 +3,14 @@
from
.quark_ocp_mx
import
QuarkOCP_MX
from
.quark_scheme
import
QuarkScheme
from
.quark_w4a8_mxfp4_fp8
import
QuarkW4A8_MXFP4_FP8
from
.quark_w8a8_fp8
import
QuarkW8A8Fp8
from
.quark_w8a8_int8
import
QuarkW8A8Int8
__all__
=
[
"QuarkScheme"
,
"QuarkW8A8Fp8"
,
"QuarkW8A8Int8"
,
"QuarkOCP_MX"
]
__all__
=
[
"QuarkScheme"
,
"QuarkW8A8Fp8"
,
"QuarkW8A8Int8"
,
"QuarkOCP_MX"
,
"QuarkW4A8_MXFP4_FP8"
,
]
vllm/model_executor/layers/quantization/quark/schemes/quark_w4a8_mxfp4_fp8.py
0 → 100644
View file @
6341d430
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
fractions
import
Fraction
from
typing
import
Any
import
torch
import
torch.nn.functional
as
F
from
vllm._aiter_ops
import
is_aiter_found_and_supported
,
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
get_fp8_min_max
,
)
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
,
PerTensorScaleParameter
,
)
from
vllm.platforms
import
current_platform
from
.quark_scheme
import
QuarkScheme
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Names exported by this module.
__all__ = ["QuarkW4A8_MXFP4_FP8"]

# Number of weight elements along the input dimension that share one scale
# entry; QuarkW4A8_MXFP4_FP8 uses it to size its `weight_scale` parameter.
OCP_MX_BLOCK_SIZE = 32
class QuarkW4A8_MXFP4_FP8(QuarkScheme):
    """Quark W4A8 linear scheme: MXFP4 weights with static FP8 activations.

    - Weights: MXFP4 with E8M0 scales per block of 32
    - Activations: FP8 E4M3 (static per-tensor quantization)

    Uses the AITER Triton kernel when AITER is found and supported on a ROCm
    gfx950 device, and falls back to emulation otherwise.
    """

    def __init__(
        self,
        weight_quant_spec: dict[str, Any],
        input_quant_spec: dict[str, Any],
    ):
        """Configure the scheme and choose between kernel and emulation.

        Args:
            weight_quant_spec: Weight quantization config. Not read here.
            input_quant_spec: Activation quantization config; its
                ``is_dynamic`` and ``qscheme`` entries are read.

        Raises:
            NotImplementedError: If ``input_quant_spec`` requests dynamic
                activation quantization.
        """
        # None means "use the activation dtype" (see _apply_aiter_kernel).
        self.out_dtype = None
        self.weight_dtype = "mxfp4"
        self.packed_factor: Fraction = Fraction(2, 1)  # 2 FP4 values per byte
        self.weight_block_size = OCP_MX_BLOCK_SIZE

        self.is_static_input_scheme = not input_quant_spec.get("is_dynamic")
        self.input_qscheme = input_quant_spec.get("qscheme")  # "per_tensor"

        self.fp8_min, self.fp8_max = get_fp8_min_max()
        self.fp8_dtype = current_platform.fp8_dtype()

        if not self.is_static_input_scheme:
            raise NotImplementedError(
                "Dynamic FP8 activation quantization is not yet supported "
                "for W4A8. The current implementation expects static per-tensor "
                "FP8 scales stored in the checkpoint."
            )

        # The kernel path is only enabled on ROCm gfx950; the ROCm helper is
        # imported lazily so other platforms never import it.
        kernel_supported_gpu = False
        if current_platform.is_rocm():
            from vllm.platforms.rocm import on_gfx950

            kernel_supported_gpu = on_gfx950()

        self.use_aiter_kernel = (
            is_aiter_found_and_supported()
            and self.is_static_input_scheme
            and kernel_supported_gpu
        )

        if not self.use_aiter_kernel:
            # The fallback is taken both when AITER is missing/unsupported and
            # when the device is not ROCm gfx950, so the message names the
            # actual requirement rather than claiming AITER was "not found".
            logger.warning_once(
                "[W4A8 MXFP4+FP8] Aiter Triton kernel unavailable "
                "(requires AITER on ROCm gfx950). Using emulation mode."
            )

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability this scheme reports (70)."""
        return 70

    def get_packed_dim(self, dim: int) -> int:
        """Return the packed length of ``dim`` FP4 values (two per byte).

        Raises:
            AssertionError: If ``dim`` is odd.
        """
        assert dim % 2 == 0, f"Dimension {dim} must be even for MXFP4 packing"
        return dim // 2

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate and register the quantized parameters on ``layer``.

        Registers ``weight`` (uint8, ``[out, in // 2]``), ``weight_scale``
        (uint8, ``[out, in // weight_block_size]``) and, for the static input
        scheme, ``input_scale`` (float32, one entry per output partition).
        Also records the partition sizes on ``layer``.

        Args:
            layer: Module that receives the parameters.
            output_partition_sizes: Output widths of the (possibly fused)
                logical partitions; their sum is the output size.
            input_size_per_partition: Unpacked input size of this partition.
            params_dtype: Not used by this scheme.
            weight_loader: Loader attached to every created parameter.
            **kwargs: Accepted and ignored.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # MXFP4 WEIGHT (packed, 2 values per byte)
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                self.get_packed_dim(input_size_per_partition),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.packed_factor,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE (E8M0 format, per block of 32)
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.weight_block_size,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (FP8 per-tensor static scale)
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(
                    len(output_partition_sizes),
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            # Initialize to avoid NaN
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace the loaded parameters with plain frozen ``Parameter``s.

        For the static input scheme, a multi-entry ``input_scale`` (fused
        modules) is reduced to its maximum before being stored as float32.
        """
        # Ensuring weights & scales are non-trainable
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.data, requires_grad=False
        )

        if self.is_static_input_scheme:
            input_scale = layer.input_scale.data
            # For fused modules (QKV), take the max scale
            if input_scale.numel() != 1:
                input_scale = input_scale.max()

            # Copy with clone().detach() rather than torch.tensor(<Tensor>),
            # which is the discouraged copy-construct path and emits a
            # UserWarning. Shape and dtype of the result are unchanged.
            # NOTE(review): the stored scale is 0-dim after the max() above
            # but keeps shape (1,) otherwise — confirm both are intended.
            layer.input_scale = torch.nn.Parameter(
                input_scale.clone().detach().to(torch.float32),
                requires_grad=False,
            )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the quantized linear output for ``x``.

        Dispatches to the AITER kernel path when it was enabled in
        ``__init__`` and to the emulation path otherwise.
        """
        if self.use_aiter_kernel:
            return self._apply_aiter_kernel(layer, x, bias)
        else:
            return self._apply_emulation(layer, x, bias)

    def _apply_aiter_kernel(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to FP8 and run the AITER a8wfp4 GEMM.

        The output dtype is ``self.out_dtype`` when set, else ``x.dtype``.
        ``bias`` is added to the GEMM result when provided.
        """
        # The first dimension of x is used as the row count for the scales.
        # assumes x is 2-D (M, K) — TODO confirm against callers.
        M = x.shape[0]
        out_dtype = x.dtype if self.out_dtype is None else self.out_dtype

        # Static quantization: divide by the stored scale, clamp to the FP8
        # range, then cast.
        input_scale = layer.input_scale
        x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype)

        # Broadcast per-tensor scale to per-row (M, 1) for Aiter kernel
        x_scales = input_scale.expand(M, 1).to(dtype=torch.float32, device=x.device)

        y = rocm_aiter_ops.gemm_a8wfp4(
            x_fp8, layer.weight, x_scales, layer.weight_scale, out_dtype
        )

        if bias is not None:
            y = y + bias

        return y

    def _apply_emulation(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Emulate the scheme with dequantized operands and ``F.linear``.

        The weights are dequantized to ``x.dtype`` on every call; ``x`` is
        round-tripped through FP8 with the static input scale so the result
        reflects the activation quantization error.
        """
        # Imported lazily, only when the emulation path actually runs.
        from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
            dequant_mxfp4,
        )

        weight_dq = dequant_mxfp4(
            layer.weight,
            layer.weight_scale,
            x.dtype,
        )

        # Quantize-dequantize the activations with the stored static scale.
        input_scale = layer.input_scale
        x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype)
        x_dq = (x_fp8.to(x.dtype) * input_scale).to(x.dtype)

        return F.linear(x_dq, weight_dq, bias)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment