Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
df1e30e7
Unverified
Commit
df1e30e7
authored
Apr 11, 2026
by
EdalatiAli
Committed by
GitHub
Apr 11, 2026
Browse files
[Quant] add CompressedTensorsW8A8Mxfp8 for linear and MoE layers (#38815)
Signed-off-by:
EdalatiAli
<
aliedalati@cohere.com
>
parent
bd8bd523
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
371 additions
and
0 deletions
+371
-0
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+36
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+25
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe.py
..._tensors/compressed_tensors_moe/compressed_tensors_moe.py
+7
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py
...mpressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py
+209
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_mxfp8.py
...mpressed_tensors/schemes/compressed_tensors_w8a8_mxfp8.py
+92
-0
No files found.
tests/quantization/test_compressed_tensors.py
View file @
df1e30e7
...
...
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A16Fp4
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A8Mxfp8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
,
)
...
...
@@ -632,3 +633,38 @@ def test_get_quant_method_returns_none_for_unmatched_parallel_lm_head():
assert
method
is
None
,
(
f
"Expected None for unmatched ParallelLMHead, got
{
type
(
method
).
__name__
}
"
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
()
or
not
current_platform
.
has_device_capability
(
75
),
reason
=
"MXFP8 requires Turing (sm_75+) or newer."
,
)
def
test_compressed_tensors_mxfp8_moe_setup
(
vllm_runner
):
"""Verify MXFP8 scheme, dtypes, and generation for a MoE model."""
model_path
=
"AliEdalati97/Qwen3-30B-A3B-MXFP8"
with
vllm_runner
(
model_path
,
enforce_eager
=
True
,
load_format
=
"dummy"
,
hf_overrides
=
{
"num_hidden_layers"
:
4
},
)
as
llm
:
def
check_model
(
model
):
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_mxfp8
import
(
# noqa: E501
CompressedTensorsW8A8Mxfp8MoEMethod
,
)
layer
=
model
.
model
.
layers
[
0
]
qkv
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv
.
scheme
,
CompressedTensorsW8A8Mxfp8
)
experts
=
layer
.
mlp
.
experts
assert
isinstance
(
experts
,
FusedMoE
)
assert
isinstance
(
experts
.
quant_method
,
CompressedTensorsW8A8Mxfp8MoEMethod
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
4
)
assert
output
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
df1e30e7
...
...
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW4A16Mxfp4
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A8Mxfp8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
,
)
...
...
@@ -403,6 +404,27 @@ class CompressedTensorsConfig(QuantizationConfig):
and
is_symmetric
)
@
staticmethod
def
_is_mxfp8
(
quant_args
:
QuantizationArgs
)
->
bool
:
if
quant_args
is
None
:
return
False
is_group_quant
=
quant_args
.
strategy
==
QuantizationStrategy
.
GROUP
.
value
is_symmetric
=
quant_args
.
symmetric
is_group_size_32
=
quant_args
.
group_size
==
32
is_float_type
=
quant_args
.
type
==
QuantizationType
.
FLOAT
is_8_bits
=
quant_args
.
num_bits
==
8
is_mxfp8_scale_dtype
=
quant_args
.
scale_dtype
==
torch
.
uint8
return
(
is_group_quant
and
is_float_type
and
is_8_bits
and
is_group_size_32
and
is_symmetric
and
is_mxfp8_scale_dtype
)
@
staticmethod
def
_is_static_tensor_w8a8
(
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
...
...
@@ -606,6 +628,9 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
_is_mxfp4
(
weight_quant
):
return
CompressedTensorsW4A16Mxfp4
()
if
self
.
_is_mxfp8
(
weight_quant
):
return
CompressedTensorsW8A8Mxfp8
()
if
self
.
_is_fp8_w4a8_sm90
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A8Fp8
(
num_bits
=
weight_quant
.
num_bits
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe.py
View file @
df1e30e7
...
...
@@ -68,6 +68,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return
CompressedTensorsW4A4Mxfp4MoEMethod
(
layer
.
moe_config
)
if
quant_config
.
_is_mxfp8
(
weight_quant
):
from
.compressed_tensors_moe_w8a8_mxfp8
import
(
CompressedTensorsW8A8Mxfp8MoEMethod
,
)
return
CompressedTensorsW8A8Mxfp8MoEMethod
(
layer
.
moe_config
)
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
# group_size=None means channelwise
group_size
=
weight_quant
.
group_size
or
-
1
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py
0 → 100644
View file @
df1e30e7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
convert_to_fp8_moe_kernel_format
,
make_fp8_moe_kernel
,
make_fp8_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.oracle.mxfp8
import
(
select_mxfp8_moe_backend
,
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe
import
(
# noqa: E501
CompressedTensorsMoEMethod
,
)
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
MXFP8_BLOCK_SIZE
,
MXFP8_SCALE_DTYPE
,
MXFP8_VALUE_DTYPE
,
)
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
class
CompressedTensorsW8A8Mxfp8MoEMethod
(
CompressedTensorsMoEMethod
):
"""Compressed-tensors MoE method for pre-quantized MXFP8 (W8A8) checkpoints.
Loads FP8 (E4M3) weights with E8M0 uint8 per-group scales (group_size=32)
from checkpoint. Activations are dynamically quantized to MXFP8 at runtime.
Supports FlashInfer TRT-LLM and Marlin backends (auto-selected).
"""
def
__init__
(
self
,
moe
:
FusedMoEConfig
):
super
().
__init__
(
moe
)
self
.
weight_block_size
=
[
1
,
MXFP8_BLOCK_SIZE
]
self
.
fp8_backend
,
self
.
experts_cls
=
select_mxfp8_moe_backend
(
config
=
self
.
moe
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
num_experts
=
num_experts
layer
.
params_dtype
=
params_dtype
w13_num_shards
=
2
if
self
.
moe
.
is_act_and_mul
else
1
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
w13_num_shards
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
MXFP8_VALUE_DTYPE
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
MXFP8_VALUE_DTYPE
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
w13_num_shards
*
intermediate_size_per_partition
,
hidden_size
//
MXFP8_BLOCK_SIZE
,
dtype
=
MXFP8_SCALE_DTYPE
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
GROUP
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
//
MXFP8_BLOCK_SIZE
,
dtype
=
MXFP8_SCALE_DTYPE
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
FusedMoE
)
->
None
:
layer
.
weight_block_size
=
self
.
weight_block_size
w13
,
w2
,
w13_scale
,
w2_scale
=
convert_to_fp8_moe_kernel_format
(
fp8_backend
=
self
.
fp8_backend
,
layer
=
layer
,
w13
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
w13_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w13_input_scale
=
layer
.
w13_input_scale
,
w2_input_scale
=
layer
.
w2_input_scale
,
)
replace_parameter
(
layer
,
"w13_weight"
,
w13
)
replace_parameter
(
layer
,
"w2_weight"
,
w2
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
w13_scale
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
w2_scale
)
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
if
self
.
moe_quant_config
is
not
None
:
assert
self
.
experts_cls
is
not
None
self
.
moe_kernel
=
make_fp8_moe_kernel
(
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
fp8_backend
=
self
.
fp8_backend
,
experts_cls
=
self
.
experts_cls
,
routing_tables
=
layer
.
_maybe_init_expert_routing_tables
(),
shared_experts
=
layer
.
shared_experts
,
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
return
make_fp8_moe_quant_config
(
fp8_backend
=
self
.
fp8_backend
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
block_shape
=
self
.
weight_block_size
,
)
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalizeModular
|
None
:
raise
ValueError
(
f
"
{
self
.
__class__
.
__name__
}
uses the new modular kernel "
"initialization logic. This function should not be called."
)
def
apply_monolithic
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply_monolithic
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
num_expert_group
=
layer
.
num_expert_group
,
topk_group
=
layer
.
topk_group
,
e_score_correction_bias
=
layer
.
e_score_correction_bias
,
routed_scaling_factor
=
layer
.
routed_scaling_factor
,
)
def
apply
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
assert
not
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
shared_experts_input
=
shared_experts_input
,
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
df1e30e7
...
...
@@ -9,6 +9,7 @@ from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
from
.compressed_tensors_w4a16_nvfp4
import
CompressedTensorsW4A16Fp4
from
.compressed_tensors_w8a8_fp8
import
CompressedTensorsW8A8Fp8
from
.compressed_tensors_w8a8_int8
import
CompressedTensorsW8A8Int8
from
.compressed_tensors_w8a8_mxfp8
import
CompressedTensorsW8A8Mxfp8
from
.compressed_tensors_w8a16_fp8
import
CompressedTensorsW8A16Fp8
from
.compressed_tensors_wNa16
import
WNA16_SUPPORTED_BITS
,
CompressedTensorsWNA16
...
...
@@ -28,4 +29,5 @@ __all__ = [
"CompressedTensorsW4A4Fp4"
,
"CompressedTensorsW4A8Int"
,
"CompressedTensorsW4A8Fp8"
,
"CompressedTensorsW8A8Mxfp8"
,
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_mxfp8.py
0 → 100644
View file @
df1e30e7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
from
vllm.model_executor.kernels.linear
import
init_mxfp8_linear_kernel
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
)
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
MXFP8_BLOCK_SIZE
,
MXFP8_SCALE_DTYPE
,
MXFP8_VALUE_DTYPE
,
)
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
ModelWeightParameter
,
)
__all__
=
[
"CompressedTensorsW8A8Mxfp8"
]
class
CompressedTensorsW8A8Mxfp8
(
CompressedTensorsScheme
):
"""
Compressed tensors scheme for MXFP8 quantization (W8A8).
Loads pre-quantized MXFP8 weights from compressed-tensors checkpoints.
Activations are dynamically quantized to MXFP8 at runtime.
MXFP8 format:
- 8-bit float weights (E4M3) stored as float8_e4m3fn
- Per-group E8M0 scales (uint8) with group_size=32
- Activations dynamically quantized to MXFP8 during inference
"""
def
__init__
(
self
):
self
.
kernel
=
init_mxfp8_linear_kernel
()
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
list
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
,
):
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
params_dtype
=
params_dtype
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
MXFP8_VALUE_DTYPE
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
weight_scale
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
//
MXFP8_BLOCK_SIZE
,
dtype
=
MXFP8_SCALE_DTYPE
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
self
.
kernel
.
process_weights_after_loading
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
return
self
.
kernel
.
apply_weights
(
layer
,
x
,
bias
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment