Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
084aa19f
Unverified
Commit
084aa19f
authored
Feb 08, 2026
by
danisereb
Committed by
GitHub
Feb 08, 2026
Browse files
Add support for ModelOpt MXFP8 dense models (#33786)
Signed-off-by:
Daniel Serebrenik
<
daserebrenik@nvidia.com
>
parent
1ecfabe5
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
375 additions
and
14 deletions
+375
-14
docs/features/quantization/modelopt.md
docs/features/quantization/modelopt.md
+1
-0
vllm/config/model.py
vllm/config/model.py
+1
-0
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+2
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+247
-2
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
+121
-11
No files found.
docs/features/quantization/modelopt.md
View file @
084aa19f
...
...
@@ -17,6 +17,7 @@ following `quantization.quant_algo` values:
-
`FP8_PER_CHANNEL_PER_TOKEN`
: per-channel weight scale and dynamic per-token activation quantization.
-
`FP8_PB_WO`
(ModelOpt may emit
`fp8_pb_wo`
): block-scaled FP8 weight-only (typically 128×128 blocks).
-
`NVFP4`
: ModelOpt NVFP4 checkpoints (use
`quantization="modelopt_fp4"`
).
-
`MXFP8`
: ModelOpt MXFP8 checkpoints (use
`quantization="modelopt_mxfp8"`
).
## Quantizing HuggingFace Models with PTQ
...
...
vllm/config/model.py
View file @
084aa19f
...
...
@@ -878,6 +878,7 @@ class ModelConfig:
"moe_wna16"
,
"modelopt"
,
"modelopt_fp4"
,
"modelopt_mxfp8"
,
"petit_nvfp4"
,
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
084aa19f
...
...
@@ -494,6 +494,7 @@ class FusedMoEQuantConfig:
"mxfp4"
,
"mxfp6_e3m2"
,
"mxfp6_e2m3"
,
"mxfp8"
,
}
assert
not
isinstance
(
weight_dtype
,
str
)
or
weight_dtype
in
{
"nvfp4"
,
...
...
@@ -501,6 +502,7 @@ class FusedMoEQuantConfig:
"mxfp6_e3m2"
,
"mxfp6_e2m3"
,
"int4"
,
"mxfp8"
,
}
if
weight_dtype
is
None
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
084aa19f
...
...
@@ -17,6 +17,7 @@ QuantizationMethods = Literal[
"fp_quant"
,
"modelopt"
,
"modelopt_fp4"
,
"modelopt_mxfp8"
,
"gguf"
,
"gptq_marlin"
,
"awq_marlin"
,
...
...
@@ -119,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.gptq
import
GPTQConfig
from
.gptq_marlin
import
GPTQMarlinConfig
from
.inc
import
INCConfig
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.modelopt
import
ModelOptFp8Config
,
ModelOptMxFp8Config
,
ModelOptNvFp4Config
from
.moe_wna16
import
MoeWNA16Config
from
.mxfp4
import
Mxfp4Config
from
.petit
import
PetitNvFp4Config
...
...
@@ -133,6 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"fp_quant"
:
FPQuantConfig
,
"modelopt"
:
ModelOptFp8Config
,
"modelopt_fp4"
:
ModelOptNvFp4Config
,
"modelopt_mxfp8"
:
ModelOptMxFp8Config
,
"gguf"
:
GGUFConfig
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
084aa19f
...
...
@@ -63,6 +63,13 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
get_marlin_input_dtype
,
)
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
MXFP8_BLOCK_SIZE
,
MXFP8_SCALE_DTYPE
,
MXFP8_VALUE_DTYPE
,
Mxfp8LinearBackend
,
Mxfp8LinearOp
,
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
apply_nvfp4_linear
,
convert_to_nvfp4_linear_kernel_format
,
...
...
@@ -103,6 +110,8 @@ QUANT_ALGOS = [
"FP8_PB_WO"
,
# FP4
"NVFP4"
,
# MXFP8
"MXFP8"
,
]
KV_CACHE_QUANT_ALGOS
=
[
"FP8"
]
...
...
@@ -386,12 +395,12 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
quant_config
=
hf_quant_cfg
[
"quantization"
]
if
isinstance
(
quant_config
,
dict
):
quant_algo
=
str
(
quant_config
.
get
(
"quant_algo"
,
""
))
if
"FP8"
in
quant_algo
.
upper
():
if
quant_algo
.
upper
()
==
"FP8"
:
return
"modelopt"
else
:
# Check for compressed-tensors style config with specific quant_algo
quant_algo
=
str
(
hf_quant_cfg
.
get
(
"quant_algo"
,
""
))
if
"FP8"
in
quant_algo
.
upper
():
if
quant_algo
.
upper
()
==
"FP8"
:
return
"modelopt"
return
None
...
...
@@ -1547,3 +1556,239 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
ModelOptNvFp4Config
.
LinearMethodCls
=
ModelOptNvFp4LinearMethod
ModelOptNvFp4Config
.
FusedMoEMethodCls
=
ModelOptNvFp4FusedMoE
ModelOptNvFp4Config
.
KVCacheMethodCls
=
ModelOptFp8KVCacheMethod
class ModelOptMxFp8Config(ModelOptQuantConfigBase):
    """Quantization config for ModelOpt checkpoints serialized as MXFP8.

    Only pre-quantized (serialized) checkpoints are accepted. ``FusedMoE``
    layers are rejected in ``get_quant_method``; every other layer is
    delegated to the base class.
    """

    def __init__(
        self,
        is_checkpoint_mxfp8_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the checkpoint flags and reject non-serialized checkpoints.

        Raises:
            ValueError: if ``is_checkpoint_mxfp8_serialized`` is false.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized

        # Fail fast: there is no dynamic (on-the-fly) MXFP8 quantization path.
        if not is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization requires a serialized checkpoint. "
                "Dynamic quantization is not supported."
            )

        logger.warning(
            "Detected ModelOpt MXFP8 checkpoint. Please note that "
            "the format is experimental and could change in future."
        )
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name for this config."""
        return "modelopt_mxfp8"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted with this config (BF16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum required GPU compute capability."""
        # MXFP8 hardware acceleration requires Blackwell (SM100) or newer
        return 100

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantize method for ``layer``, refusing MoE layers.

        Raises:
            NotImplementedError: if ``layer`` is a ``FusedMoE``.
        """
        # MXFP8 does not yet support MoE models
        if isinstance(layer, FusedMoE):
            raise NotImplementedError(
                "MXFP8 quantization does not yet support MoE models. "
                "Please use FP8 or NVFP4 quantization for MoE models."
            )
        return super().get_quant_method(layer, prefix)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Return ``"modelopt_mxfp8"`` when the HF quant config describes a
        ModelOpt MXFP8 checkpoint, otherwise ``None``."""
        if hf_quant_cfg is None:
            return None

        # Only configs that explicitly declare quant_method == "modelopt"
        # (the community-standard key) are considered.
        declared_method = hf_quant_cfg.get("quant_method", "").lower()
        if declared_method != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific layout: the algo lives in a nested dict.
            nested_cfg = hf_quant_cfg["quantization"]
            if not isinstance(nested_cfg, dict):
                return None
            algo_name = str(nested_cfg.get("quant_algo", "")).upper()
        else:
            # Compressed-tensors style layout: the algo is a top-level key.
            algo_name = str(hf_quant_cfg.get("quant_algo", "")).upper()

        return "modelopt_mxfp8" if "MXFP8" in algo_name else None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptMxFp8Config":
        """Build the config from already-parsed quantization settings.

        Raises:
            ValueError: if the checkpoint is MXFP8 and the ``quantization``
                section of ``original_config`` lacks a required field.
        """
        is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper()

        # For MXFP8, validate required fields in the config
        if is_checkpoint_mxfp8_serialized and "quantization" in original_config:
            section = original_config["quantization"]
            absent = [
                name
                for name in ("kv_cache_quant_algo", "exclude_modules")
                if name not in section
            ]
            if absent:
                raise ValueError(
                    f"MXFP8 quantization requires the following fields in "
                    f"hf_quant_config.json: {absent}"
                )

        return cls(
            is_checkpoint_mxfp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )
class ModelOptMxFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt MXFP8 checkpoints.

    Allocates FP8 weights plus per-block scales, tidies them after loading,
    and runs the GEMM through ``Mxfp8LinearOp`` (emulation backend).
    """

    def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
        """Bind the config and create the MXFP8 linear op.

        Raises:
            ValueError: if the checkpoint is not MXFP8-serialized.
        """
        self.quant_config = quant_config

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 currently only supports serialized checkpoints. "
                "Dynamic quantization is not supported."
            )

        # Emulation is the only backend currently selected here.
        backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION
        self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend)
        logger.info_once("Using %s backend for MXFP8 GEMM", backend.value)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the ``weight`` and ``weight_scale`` parameters on ``layer``.

        ``input_size`` and ``output_size`` are accepted for interface
        compatibility and ignored.

        Raises:
            ValueError: if the checkpoint is not MXFP8-serialized, or if
                ``input_size_per_partition`` is not a multiple of
                ``MXFP8_BLOCK_SIZE``.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization was selected, but checkpoint is not "
                "MXFP8 serialized. Dynamic quantization is not supported."
            )

        loader = extra_weight_attrs.get("weight_loader")
        out_features = sum(output_partition_sizes)

        # Partition bookkeeping is stored on the layer before validation,
        # exactly as the loader-facing attributes are expected to exist.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        if input_size_per_partition % MXFP8_BLOCK_SIZE != 0:
            raise ValueError(
                f"MXFP8 requires input dimension to be divisible by "
                f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}"
            )

        # Weight tensor: FP8 E4M3 format
        weight_storage = torch.empty(
            out_features,
            input_size_per_partition,
            dtype=MXFP8_VALUE_DTYPE,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # Weight scale tensor (E8M0 encoded as uint8), one scale per block
        # of MXFP8_BLOCK_SIZE elements along K
        scale_storage = torch.empty(
            out_features,
            input_size_per_partition // MXFP8_BLOCK_SIZE,
            dtype=MXFP8_SCALE_DTYPE,
        )
        layer.register_parameter(
            "weight_scale",
            ModelWeightParameter(
                data=scale_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Validate the loaded weight and replace both tensors with plain,
        contiguous, non-trainable ``Parameter`` objects.

        Raises:
            ValueError: if the weight is not 2D or not ``MXFP8_VALUE_DTYPE``.
        """
        loaded = layer.weight
        if loaded.ndim != 2:
            raise ValueError(
                f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
                f"with shape {tuple(layer.weight.shape)}"
            )
        if loaded.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), "
                f"got {layer.weight.dtype}. The checkpoint may not be properly "
                f"quantized with MXFP8."
            )

        weight_data = loaded.data  # [N, K]
        n_rows, n_cols = weight_data.shape
        n_scale_cols = n_cols // MXFP8_BLOCK_SIZE

        # Slice weight_scale to match weight dimensions (handles padding)
        scale_data = layer.weight_scale.data[:n_rows, :n_scale_cols].contiguous()

        layer.weight = Parameter(weight_data.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(scale_data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear op on ``x``; the output keeps ``x.dtype``.

        Raises:
            ValueError: if the layer's weight or scale has an unexpected dtype.
        """
        weight_dtype = layer.weight.dtype
        if weight_dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
            )

        scale_dtype = layer.weight_scale.dtype
        if scale_dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Weight scale dtype {layer.weight_scale.dtype} != "
                f"expected {MXFP8_SCALE_DTYPE}"
            )

        return self.mxfp8_linear_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=x.dtype,
            bias=bias,
        )
# Register the method classes for ModelOptMxFp8Config.
# Only the linear and KV-cache method classes are assigned here; no fused-MoE
# method class is set for this config.
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
View file @
084aa19f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
import
torch
from
vllm.logger
import
init_logger
from
vllm.utils.torch_utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
def
mxfp8_e4m3_quantize
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
try
:
from
flashinfer
import
mxfp8_quantize
as
mxfp8_e4m3_quantize
except
ImportError
as
err
:
raise
ImportError
(
"The package `flashinfer` is required to do "
"
MX
-
FP8
quantization. Please install it with"
"`pip install flashinfer`"
)
from
err
class Mxfp8LinearBackend(Enum):
    """Backends available for running an MXFP8 linear layer."""

    EMULATION = "emulation"
# MXFP8 constants
MXFP8_VALUE_DTYPE
=
torch
.
float8_e4m3fn
MXFP8
_SCALE_DTYPE
=
torch
.
uint8
MXFP8_BLOCK_SIZE
=
32
x_q
,
x_scales
=
mxfp8_e4m3_quantize
(
x
,
is_sf_swizzled_layout
=
False
)
if
x_scales
.
ndim
==
1
:
def _mxfp8_e4m3_quantize_impl(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` with flashinfer's ``mxfp8_quantize``.

    Returns the ``(quantized values, scales)`` pair from flashinfer. When the
    input is 2D, the layout is not swizzled and flashinfer hands back a flat
    1D scale tensor, the scales are viewed as one row per input row.
    """
    # Imported lazily so flashinfer is only needed when the op is executed.
    from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize

    quantized, scale_factors = flashinfer_mxfp8_quantize(
        x, is_sf_swizzled_layout=is_sf_swizzled_layout
    )

    flat_linear_scales = (
        scale_factors.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout
    )
    if flat_linear_scales:
        scale_factors = scale_factors.view(x.size(0), -1)

    return quantized, scale_factors
def mxfp8_e4m3_quantize(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MXFP8 by dispatching to the ``vllm.mxfp8_quantize``
    custom op; returns whatever pair of tensors that op produces."""
    quantize_op = torch.ops.vllm.mxfp8_quantize
    return quantize_op(x, is_sf_swizzled_layout)
def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Dequantize an MXFP8 tensor to BF16.

    The last dimension of ``x`` is split into blocks of ``MXFP8_BLOCK_SIZE``
    elements and every block is multiplied by ``2 ** (scale - 127)``.
    ``scales`` must broadcast against the leading dimensions of ``x`` plus
    one entry per block.
    """
    exponent_bias = 127.0
    original_shape = x.shape
    blocks_per_row = original_shape[-1] // MXFP8_BLOCK_SIZE

    # All arithmetic is done in float32 before the final cast to BF16.
    values = x.to(torch.float32).view(
        *original_shape[:-1], blocks_per_row, MXFP8_BLOCK_SIZE
    )
    multipliers = torch.exp2(scales.to(torch.float32) - exponent_bias)

    # unsqueeze(-1) lets each scale broadcast over the elements of its block.
    scaled = values * multipliers.unsqueeze(-1)
    return scaled.view(*original_shape).to(torch.bfloat16)
def mxfp8_e4m3_quantize_fake(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile tracing.

    Allocates uninitialized outputs with the shapes and dtypes only:
    * values: same shape as ``x`` in ``MXFP8_VALUE_DTYPE``;
    * scales: for 2D/3D inputs with the swizzled layout, a flat tensor whose
      rows are padded to a multiple of 128 and whose per-row block count is
      padded to a multiple of 4; in every other case the shape of ``x`` with
      the last dimension replaced by the (ceil) number of blocks.
    """
    quantized = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE)

    # Ceil-divide the last dimension into scale blocks.
    blocks_per_row = (x.shape[-1] + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE

    if is_sf_swizzled_layout and x.ndim in (2, 3):
        rows_padded = ((x.shape[-2] + 127) // 128) * 128
        blocks_padded = ((blocks_per_row + 3) // 4) * 4
        if x.ndim == 3:
            flat_len = x.shape[0] * rows_padded * blocks_padded
        else:
            flat_len = rows_padded * blocks_padded
        scale_factors = torch.empty(
            flat_len, dtype=MXFP8_SCALE_DTYPE, device=x.device
        )
    else:
        # Linear layout (and any other rank): one scale per block, keeping
        # all leading dimensions of the input.
        linear_shape = list(x.shape)
        linear_shape[-1] = blocks_per_row
        scale_factors = torch.empty(
            linear_shape, dtype=MXFP8_SCALE_DTYPE, device=x.device
        )

    return quantized, scale_factors
# Register the flashinfer-backed quantize function as the custom op
# "mxfp8_quantize", with a fake implementation supplied for tracing.
# mxfp8_e4m3_quantize in this module calls it as torch.ops.vllm.mxfp8_quantize.
direct_register_custom_op(
    op_name="mxfp8_quantize",
    op_func=_mxfp8_e4m3_quantize_impl,
    fake_impl=mxfp8_e4m3_quantize_fake,
)
class Mxfp8LinearOp:
    """Executes a linear layer whose weights are stored in MXFP8.

    The only backend is emulation: the weight is dequantized to BF16 with
    ``dequant_mxfp8_to_bf16`` and a regular ``torch.nn.functional.linear``
    is run on the result.
    """

    def __init__(self, backend: Mxfp8LinearBackend):
        """Store the backend to use.

        Args:
            backend: a ``Mxfp8LinearBackend`` member (its raw value is also
                accepted and normalized to the member).

        Raises:
            ValueError: if ``backend`` does not name a supported backend.
        """
        # `x in EnumClass` is version-dependent for non-members (TypeError
        # before Python 3.12, value matching from 3.12 on). Going through the
        # enum constructor gives one behavior everywhere and guarantees that
        # self.backend is always a real enum member.
        try:
            self.backend = Mxfp8LinearBackend(backend)
        except ValueError:
            raise ValueError(f"Unsupported backend: {backend}") from None

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``linear(input, dequant(weight), bias)`` cast to ``out_dtype``.

        Args:
            input: activations passed straight to ``F.linear``.
            weight: MXFP8 weight values.
            weight_scale: 2D per-block scales in ``MXFP8_SCALE_DTYPE``.
            out_dtype: dtype of the returned tensor.
            bias: optional bias forwarded to ``F.linear``.

        Raises:
            ValueError: if ``weight_scale`` has the wrong dtype or is not 2D.
        """
        # Validate weight_scale dtype and shape (the emulation path needs the
        # plain 2D scale layout to dequantize).
        if weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"MXFP8 emulation backend requires {MXFP8_SCALE_DTYPE} "
                f"weight_scale dtype, got {weight_scale.dtype}."
            )
        if weight_scale.ndim != 2:
            raise ValueError(
                f"MXFP8 emulation backend requires 2D weight_scale, "
                f"got {weight_scale.ndim}D. "
                f"Ensure process_weights_after_loading was called."
            )

        weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale)
        output = torch.nn.functional.linear(input, weight_bf16, bias)
        return output.to(out_dtype)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment