Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a6a3a12
Unverified
Commit
0a6a3a12
authored
Mar 08, 2026
by
danisereb
Committed by
GitHub
Mar 08, 2026
Browse files
Add support for ModelOpt MXFP8 MoE models (#35986)
Signed-off-by:
Daniel Serebrenik
<
daserebrenik@nvidia.com
>
parent
4497431d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
597 additions
and
18 deletions
+597
-18
tests/kernels/moe/test_ocp_mx_moe.py
tests/kernels/moe/test_ocp_mx_moe.py
+185
-2
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-0
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
+44
-0
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+359
-16
No files found.
tests/kernels/moe/test_ocp_mx_moe.py
View file @
0a6a3a12
...
@@ -20,6 +20,8 @@ TRTLLM_GEN_MXFP4_AVAILABLE = (
...
@@ -20,6 +20,8 @@ TRTLLM_GEN_MXFP4_AVAILABLE = (
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability_family
(
100
)
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability_family
(
100
)
)
)
TRTLLM_GEN_MXFP8_AVAILABLE
=
TRTLLM_GEN_MXFP4_AVAILABLE
HOPPER_MXFP4_BF16_AVAILABLE
=
(
HOPPER_MXFP4_BF16_AVAILABLE
=
(
current_platform
.
is_cuda
()
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
90
)
and
current_platform
.
is_device_capability
(
90
)
...
@@ -34,9 +36,15 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
...
@@ -34,9 +36,15 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
shuffle_matrix_a
,
shuffle_matrix_a
,
shuffle_matrix_sf_a
,
shuffle_matrix_sf_a
,
trtllm_fp4_block_scale_moe
,
trtllm_fp4_block_scale_moe
,
trtllm_fp8_block_scale_moe
,
)
)
from
flashinfer.fp4_quantization
import
nvfp4_block_scale_interleave
from
flashinfer.fp4_quantization
import
nvfp4_block_scale_interleave
from
flashinfer.fused_moe.core
import
get_w2_permute_indices_with_cache
if
TRTLLM_GEN_MXFP8_AVAILABLE
:
from
flashinfer.fused_moe.core
import
(
Fp8QuantizationType
,
get_w2_permute_indices_with_cache
,
)
@
dataclass
@
dataclass
...
@@ -160,6 +168,7 @@ def reference_moe(
...
@@ -160,6 +168,7 @@ def reference_moe(
beta
,
beta
,
limit
,
limit
,
act_type
,
act_type
,
is_gated
,
):
):
# renormalize routing
# renormalize routing
experts
=
torch
.
topk
(
roouting_logits
,
k
=
topk
,
dim
=-
1
,
sorted
=
True
)
experts
=
torch
.
topk
(
roouting_logits
,
k
=
topk
,
dim
=-
1
,
sorted
=
True
)
...
@@ -170,7 +179,12 @@ def reference_moe(
...
@@ -170,7 +179,12 @@ def reference_moe(
mlp1_weight
=
w13
[
expert_indices
,
...]
mlp1_weight
=
w13
[
expert_indices
,
...]
mlp1_bias
=
bias13
[
expert_indices
,
...]
mlp1_bias
=
bias13
[
expert_indices
,
...]
t
=
torch
.
einsum
(
"beck,bk->bec"
,
mlp1_weight
,
t
)
+
mlp1_bias
t
=
torch
.
einsum
(
"beck,bk->bec"
,
mlp1_weight
,
t
)
+
mlp1_bias
t
=
swiglu
(
t
,
alpha
=
alpha
,
beta
=
beta
,
limit
=
limit
)
if
is_gated
:
t
=
swiglu
(
t
,
alpha
=
alpha
,
beta
=
beta
,
limit
=
limit
)
else
:
# RELU2_NO_MUL: relu(x)^2
t
=
torch
.
relu
(
t
)
t
=
t
*
t
if
act_type
==
"mxfp8"
:
if
act_type
==
"mxfp8"
:
t_quantized
,
t_scale
=
mxfp8_quantize
(
t_quantized
,
t_scale
=
mxfp8_quantize
(
...
@@ -569,6 +583,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
...
@@ -569,6 +583,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
beta
,
beta
,
limit
,
limit
,
act_type
,
act_type
,
is_gated
=
True
,
)
)
ref_result
[
start_idx
:
end_idx
].
copy_
(
chunk_result
)
ref_result
[
start_idx
:
end_idx
].
copy_
(
chunk_result
)
...
@@ -705,6 +720,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
...
@@ -705,6 +720,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
beta
,
beta
,
limit
,
limit
,
"bf16"
,
"bf16"
,
is_gated
=
True
,
)
)
from
vllm.utils.flashinfer
import
flashinfer_cutlass_fused_moe
from
vllm.utils.flashinfer
import
flashinfer_cutlass_fused_moe
...
@@ -890,6 +906,7 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
...
@@ -890,6 +906,7 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
beta
,
beta
,
limit
,
limit
,
"mxfp8"
,
"mxfp8"
,
is_gated
=
True
,
)
)
# Prepare inputs for FlashInfer CUTLASS fused MoE
# Prepare inputs for FlashInfer CUTLASS fused MoE
...
@@ -965,3 +982,169 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
...
@@ -965,3 +982,169 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
# Allow some mismatch due to MXFP4 quantization
# Allow some mismatch due to MXFP4 quantization
check_accuracy
(
ref
,
out
,
atol
=
0
,
rtol
=
0.3
,
percent
=
0.8
)
check_accuracy
(
ref
,
out
,
atol
=
0
,
rtol
=
0.3
,
percent
=
0.8
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
128
])
@
pytest
.
mark
.
parametrize
(
"intermediate_size,hidden_size"
,
[(
3072
,
3072
)])
@
pytest
.
mark
.
parametrize
(
"is_gated"
,
[
True
],
ids
=
[
"gated"
])
@
pytest
.
mark
.
skipif
(
not
TRTLLM_GEN_MXFP8_AVAILABLE
,
reason
=
"nvidia gpu and compute capability sm100 is required for this test"
,
)
def
test_trtllm_gen_mxfp8_block_scale_moe
(
topk
:
int
,
num_experts
:
int
,
num_tokens
:
int
,
intermediate_size
:
int
,
hidden_size
:
int
,
is_gated
:
bool
,
):
torch
.
manual_seed
(
42
)
device
=
"cuda:0"
inter_size
=
intermediate_size
*
(
2
if
is_gated
else
1
)
hidden_states
=
(
torch
.
randn
(
num_tokens
,
hidden_size
,
device
=
device
,
dtype
=
torch
.
bfloat16
)
/
20
)
w13
=
(
torch
.
randn
(
num_experts
,
inter_size
,
hidden_size
,
device
=
device
,
dtype
=
torch
.
bfloat16
,
)
/
20
)
w2
=
(
torch
.
randn
(
num_experts
,
hidden_size
,
intermediate_size
,
device
=
device
,
dtype
=
torch
.
bfloat16
,
)
/
20
)
router_logits
=
torch
.
rand
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
router_logits_kernel
=
router_logits
.
to
(
torch
.
bfloat16
)
# Quantize weights to MXFP8 and normalize scales to [E, M, K//32].
w13_q
,
w13_scale
=
mxfp8_quantize
(
w13
,
is_sf_swizzled_layout
=
False
)
w2_q
,
w2_scale
=
mxfp8_quantize
(
w2
,
is_sf_swizzled_layout
=
False
)
if
w13_scale
.
ndim
==
1
:
w13_scale
=
w13_scale
.
view
(
num_experts
,
inter_size
,
hidden_size
//
32
,
)
if
w2_scale
.
ndim
==
1
:
w2_scale
=
w2_scale
.
view
(
num_experts
,
hidden_size
,
intermediate_size
//
32
)
# Quantize activations to MXFP8.
hidden_states_q
,
hidden_states_scale
=
mxfp8_quantize
(
hidden_states
,
is_sf_swizzled_layout
=
False
)
if
hidden_states_scale
.
ndim
==
1
:
hidden_states_scale
=
hidden_states_scale
.
view
(
num_tokens
,
hidden_size
//
32
)
# Reference output using dequantized tensors + MXFP8 intermediate quantization.
w13_ref
=
mxfp8_dequantize
(
w13_q
,
w13_scale
).
to
(
torch
.
float32
)
w2_ref
=
mxfp8_dequantize
(
w2_q
,
w2_scale
).
to
(
torch
.
float32
)
hidden_states_ref
=
mxfp8_dequantize
(
hidden_states_q
,
hidden_states_scale
).
to
(
torch
.
float32
)
bias13
=
torch
.
zeros
(
num_experts
,
intermediate_size
*
(
2
if
is_gated
else
1
),
device
=
device
,
)
bias2
=
torch
.
zeros
(
num_experts
,
hidden_size
,
device
=
device
)
ref
=
reference_moe
(
router_logits_kernel
.
to
(
torch
.
float32
),
topk
,
num_experts
,
hidden_states_ref
,
w13_ref
,
bias13
,
w2_ref
,
bias2
,
alpha
=
1.0
,
beta
=
0.0
,
limit
=
None
,
act_type
=
"mxfp8"
,
is_gated
=
is_gated
,
)
# Shuffle weights/scales with the same indexed layout used by TRTLLM kernels.
epilogue_tile_m
=
128
gemm1_weights_shuffled
=
[]
gemm1_scales_shuffled
=
[]
gemm2_weights_shuffled
=
[]
gemm2_scales_shuffled
=
[]
for
i
in
range
(
num_experts
):
w13_rows
=
intermediate_size
*
(
2
if
is_gated
else
1
)
w13_interleaved
=
w13_q
[
i
].
clone
().
reshape
(
w13_rows
,
-
1
)
w13_scale_interleaved
=
w13_scale
[
i
].
clone
().
reshape
(
w13_rows
,
-
1
)
if
is_gated
:
w13_interleaved
=
reorder_rows_for_gated_act_gemm
(
w13_interleaved
)
w13_scale_interleaved
=
reorder_rows_for_gated_act_gemm
(
w13_scale_interleaved
)
gemm1_weights_shuffled
.
append
(
shuffle_matrix_a
(
w13_interleaved
.
view
(
torch
.
uint8
),
epilogue_tile_m
)
.
contiguous
()
.
view
(
w13_q
.
dtype
)
)
gemm2_weights_shuffled
.
append
(
shuffle_matrix_a
(
w2_q
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
)
.
contiguous
()
.
view
(
w2_q
.
dtype
)
)
gemm1_scales_shuffled
.
append
(
shuffle_matrix_sf_a
(
w13_scale_interleaved
.
view
(
torch
.
uint8
).
reshape
(
w13_rows
,
-
1
),
epilogue_tile_m
,
)
.
contiguous
()
.
view
(
w13_scale
.
dtype
)
)
gemm2_scales_shuffled
.
append
(
shuffle_matrix_sf_a
(
w2_scale
[
i
].
view
(
torch
.
uint8
).
reshape
(
hidden_size
,
-
1
),
epilogue_tile_m
)
.
contiguous
()
.
view
(
w2_scale
.
dtype
)
)
out
=
trtllm_fp8_block_scale_moe
(
routing_logits
=
router_logits_kernel
,
routing_bias
=
None
,
hidden_states
=
hidden_states_q
,
hidden_states_scale
=
hidden_states_scale
,
gemm1_weights
=
torch
.
stack
(
gemm1_weights_shuffled
),
gemm1_weights_scale
=
torch
.
stack
(
gemm1_scales_shuffled
),
gemm2_weights
=
torch
.
stack
(
gemm2_weights_shuffled
),
gemm2_weights_scale
=
torch
.
stack
(
gemm2_scales_shuffled
),
num_experts
=
num_experts
,
top_k
=
topk
,
n_group
=
None
,
topk_group
=
None
,
intermediate_size
=
intermediate_size
,
local_expert_offset
=
0
,
local_num_experts
=
num_experts
,
routed_scaling_factor
=
None
,
routing_method_type
=
1
,
# renormalize routing
use_shuffled_weight
=
True
,
weight_layout
=
0
,
# MajorK
fp8_quantization_type
=
Fp8QuantizationType
.
MxFp8
,
)
# Block-scale MXFP8 kernels are approximate; require majority close.
check_accuracy
(
ref
,
out
,
atol
=
0.1
,
rtol
=
0.85
,
percent
=
0.8
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
0a6a3a12
...
@@ -1204,17 +1204,26 @@ class FusedMoE(CustomOp):
...
@@ -1204,17 +1204,26 @@ class FusedMoE(CustomOp):
# Determine per-tensor weight scale patterns based on variant
# Determine per-tensor weight scale patterns based on variant
# Use the dedicated method instead of brittle string matching
# Use the dedicated method instead of brittle string matching
uses_weight_scale_2
=
self
.
quant_method
.
uses_weight_scale_2_pattern
()
uses_weight_scale_2
=
self
.
quant_method
.
uses_weight_scale_2_pattern
()
quant_method
=
getattr
(
param
,
"quant_method"
,
None
)
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
# weights scales.
# weights scales.
# Input scales are always per-tensor.
# Input scales are always per-tensor.
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
# "weight_scale" for per-tensor scales.
# "weight_scale" for per-tensor scales.
# NOTE: ModelOpt MXFP8 MoE uses block scales in weight_scale
# tensors (quant_method=BLOCK), so those must not be treated
# as per-tensor scalars here.
is_block_weight_scale
=
(
"weight_scale"
in
weight_name
and
quant_method
==
FusedMoeWeightScaleSupported
.
BLOCK
.
value
)
is_per_tensor
=
(
is_per_tensor
=
(
"weight_scale_2"
in
weight_name
"weight_scale_2"
in
weight_name
if
uses_weight_scale_2
if
uses_weight_scale_2
else
"weight_scale"
in
weight_name
else
"weight_scale"
in
weight_name
)
or
"input_scale"
in
weight_name
)
or
"input_scale"
in
weight_name
is_per_tensor
=
is_per_tensor
and
not
is_block_weight_scale
if
is_per_tensor
:
if
is_per_tensor
:
self
.
_load_per_tensor_weight_scale
(
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
shard_id
=
shard_id
,
...
...
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
0 → 100644
View file @
0a6a3a12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
logger
=
init_logger
(
__name__
)
class
MxFp8MoeBackend
(
Enum
):
FLASHINFER_TRTLLM
=
"FLASHINFER_TRTLLM"
def
select_mxfp8_moe_backend
(
config
:
FusedMoEConfig
,
)
->
MxFp8MoeBackend
:
if
config
.
is_lora_enabled
:
raise
NotImplementedError
(
"LoRA is not supported for MXFP8 MoE."
)
AVAILABLE_BACKENDS
=
[
MxFp8MoeBackend
.
FLASHINFER_TRTLLM
,
]
runner_backend
=
config
.
moe_backend
if
runner_backend
!=
"auto"
:
mapping
=
{
"flashinfer_trtllm"
:
MxFp8MoeBackend
.
FLASHINFER_TRTLLM
,
}
if
backend
:
=
mapping
.
get
(
runner_backend
):
logger
.
info_once
(
"Using '%s' MxFp8 MoE backend (user-requested)."
,
backend
.
value
,
)
return
backend
raise
ValueError
(
f
"moe_backend='
{
runner_backend
}
' is not supported for MXFP8 MoE. "
f
"Expected one of
{
list
(
mapping
.
keys
())
}
."
)
# Auto-select: only one backend available for now.
backend
=
AVAILABLE_BACKENDS
[
0
]
logger
.
info_once
(
"Using '%s' MxFp8 MoE backend."
,
backend
.
value
)
return
backend
vllm/model_executor/layers/quantization/modelopt.py
View file @
0a6a3a12
...
@@ -9,17 +9,19 @@ from torch.nn.parameter import Parameter
...
@@ -9,17 +9,19 @@ from torch.nn.parameter import Parameter
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.kernels.linear
import
(
from
vllm.model_executor.kernels.linear
import
init_fp8_linear_kernel
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.attention
import
Attention
,
MLAAttention
from
vllm.model_executor.layers.attention
import
Attention
,
MLAAttention
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
)
)
from
vllm.model_executor.layers.fused_moe.layer
import
(
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
FusedMoeWeightScaleSupported
,
)
)
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
...
@@ -28,6 +30,10 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
...
@@ -28,6 +30,10 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_quant_config
,
make_fp8_moe_quant_config
,
select_fp8_moe_backend
,
select_fp8_moe_backend
,
)
)
from
vllm.model_executor.layers.fused_moe.oracle.mxfp8
import
(
MxFp8MoeBackend
,
select_mxfp8_moe_backend
,
)
from
vllm.model_executor.layers.fused_moe.oracle.nvfp4
import
(
from
vllm.model_executor.layers.fused_moe.oracle.nvfp4
import
(
convert_to_nvfp4_moe_kernel_format
,
convert_to_nvfp4_moe_kernel_format
,
is_global_sf_supported_for_nvfp4_backend
,
is_global_sf_supported_for_nvfp4_backend
,
...
@@ -46,6 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -46,6 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
swap_w13_to_w31
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
W8A8BlockFp8LinearOp
,
process_fp8_input_tensor_strategy_moe
,
process_fp8_input_tensor_strategy_moe
,
...
@@ -60,6 +69,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
...
@@ -60,6 +69,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_VALUE_DTYPE
,
MXFP8_VALUE_DTYPE
,
Mxfp8LinearBackend
,
Mxfp8LinearBackend
,
Mxfp8LinearOp
,
Mxfp8LinearOp
,
mxfp8_e4m3_quantize
,
swizzle_mxfp8_scale
,
swizzle_mxfp8_scale
,
)
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
...
@@ -86,7 +96,8 @@ from vllm.model_executor.parameter import (
...
@@ -86,7 +96,8 @@ from vllm.model_executor.parameter import (
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
,
PerTensorScaleParameter
,
)
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.utils.flashinfer
import
flashinfer_trtllm_fp8_block_scale_moe
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.model_executor.models.utils
import
WeightsMapper
...
@@ -1487,17 +1498,6 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
...
@@ -1487,17 +1498,6 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
# MXFP8 hardware acceleration requires Blackwell (SM100) or newer
# MXFP8 hardware acceleration requires Blackwell (SM100) or newer
return
100
return
100
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
"QuantizeMethodBase | None"
:
# MXFP8 does not yet support MoE models
if
isinstance
(
layer
,
FusedMoE
):
raise
NotImplementedError
(
"MXFP8 quantization does not yet support MoE models. "
"Please use FP8 or NVFP4 quantization for MoE models."
)
return
super
().
get_quant_method
(
layer
,
prefix
)
@
classmethod
@
classmethod
def
override_quantization_method
(
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
cls
,
hf_quant_cfg
,
user_quant
...
@@ -1699,8 +1699,351 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
...
@@ -1699,8 +1699,351 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
)
)
class
ModelOptMxFp8FusedMoE
(
FusedMoEMethodBase
):
"""FlashInfer TRTLLM MXFP8 block-scale MoE for ModelOpt checkpoints."""
def
__init__
(
self
,
quant_config
:
ModelOptMxFp8Config
,
moe_config
:
FusedMoEConfig
,
)
->
None
:
super
().
__init__
(
moe_config
)
self
.
quant_config
=
quant_config
assert
self
.
quant_config
.
is_checkpoint_mxfp8_serialized
# Select MXFP8 MoE backend
self
.
mxfp8_backend
=
select_mxfp8_moe_backend
(
self
.
moe
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
orig_dtype
=
params_dtype
if
hidden_size
%
MXFP8_BLOCK_SIZE
!=
0
:
raise
ValueError
(
f
"MXFP8 MoE requires hidden_size divisible by
{
MXFP8_BLOCK_SIZE
}
, "
f
"got
{
hidden_size
}
."
)
if
intermediate_size_per_partition
%
MXFP8_BLOCK_SIZE
!=
0
:
raise
ValueError
(
"MXFP8 MoE requires intermediate_size_per_partition divisible by "
f
"
{
MXFP8_BLOCK_SIZE
}
, got
{
intermediate_size_per_partition
}
."
)
layer
.
num_experts
=
num_experts
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
w13_num_shards
=
2
if
self
.
moe
.
is_act_and_mul
else
1
# GEMM 1 weights: [E, (2I or I), H]
w13_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
num_experts
,
w13_num_shards
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
MXFP8_VALUE_DTYPE
,
),
input_dim
=
2
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
# GEMM 2 weights: [E, H, I]
w2_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
MXFP8_VALUE_DTYPE
,
),
input_dim
=
2
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
# Per-block (K=32) E8M0 scales.
w13_weight_scale
=
ModelWeightParameter
(
data
=
torch
.
empty
(
num_experts
,
w13_num_shards
*
intermediate_size_per_partition
,
hidden_size
//
MXFP8_BLOCK_SIZE
,
dtype
=
MXFP8_SCALE_DTYPE
,
),
input_dim
=
2
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
ModelWeightParameter
(
data
=
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
//
MXFP8_BLOCK_SIZE
,
dtype
=
MXFP8_SCALE_DTYPE
,
),
input_dim
=
2
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Ensure the generic MoE weight-loader treats these as block scales.
set_weight_attrs
(
layer
.
w13_weight_scale
,
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
BLOCK
.
value
},
)
set_weight_attrs
(
layer
.
w2_weight_scale
,
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
BLOCK
.
value
},
)
@
staticmethod
def
_check_weight_dtypes
(
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Validate weight and scale dtypes before processing."""
expected
=
{
"w13_weight"
:
MXFP8_VALUE_DTYPE
,
"w2_weight"
:
MXFP8_VALUE_DTYPE
,
"w13_weight_scale"
:
MXFP8_SCALE_DTYPE
,
"w2_weight_scale"
:
MXFP8_SCALE_DTYPE
,
}
for
name
,
expected_dtype
in
expected
.
items
():
actual
=
getattr
(
layer
,
name
).
dtype
if
actual
!=
expected_dtype
:
raise
ValueError
(
f
"Expected
{
name
}
dtype
{
expected_dtype
}
, got
{
actual
}
."
)
def
_shuffle_weights_for_trtllm
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Shuffle weights and scales into FlashInfer TRTLLM MXFP8 layout."""
from
flashinfer
import
(
reorder_rows_for_gated_act_gemm
,
shuffle_matrix_a
,
shuffle_matrix_sf_a
,
)
epilogue_tile_m
=
128
num_experts
=
layer
.
w13_weight
.
shape
[
0
]
is_gated
=
self
.
moe
.
is_act_and_mul
intermediate_size_factor
=
2
if
is_gated
else
1
w13_weight
=
layer
.
w13_weight
.
data
w13_scale
=
layer
.
w13_weight_scale
.
data
if
is_gated
:
# FI TRTLLM gated kernels use W31 ordering. Model checkpoints store
# gated projection as W13, so convert once before shuffling.
w13_weight
=
swap_w13_to_w31
(
w13_weight
)
w13_scale
=
swap_w13_to_w31
(
w13_scale
)
w13_weight_shuffled
=
[]
w2_weight_shuffled
=
[]
w13_scale_shuffled
=
[]
w2_scale_shuffled
=
[]
for
i
in
range
(
num_experts
):
w13_i
=
w13_weight
[
i
].
reshape
(
intermediate_size_factor
*
layer
.
intermediate_size_per_partition
,
-
1
)
w13_sf_i
=
w13_scale
[
i
].
reshape
(
intermediate_size_factor
*
layer
.
intermediate_size_per_partition
,
-
1
)
if
is_gated
:
# Reorder rows for gated activation layout expected by TRTLLM.
w13_i
=
reorder_rows_for_gated_act_gemm
(
w13_i
.
clone
())
w13_sf_i
=
reorder_rows_for_gated_act_gemm
(
w13_sf_i
.
clone
())
w13_shuffled_i
=
shuffle_matrix_a
(
w13_i
.
view
(
torch
.
uint8
),
epilogue_tile_m
)
w2_shuffled_i
=
shuffle_matrix_a
(
layer
.
w2_weight
.
data
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
)
w13_weight_shuffled
.
append
(
w13_shuffled_i
.
contiguous
().
view
(
MXFP8_VALUE_DTYPE
)
)
w2_weight_shuffled
.
append
(
w2_shuffled_i
.
contiguous
().
view
(
MXFP8_VALUE_DTYPE
)
)
w13_sf_shuffled_i
=
shuffle_matrix_sf_a
(
w13_sf_i
.
view
(
torch
.
uint8
).
reshape
(
intermediate_size_factor
*
layer
.
intermediate_size_per_partition
,
-
1
,
),
epilogue_tile_m
,
)
w2_sf_shuffled_i
=
shuffle_matrix_sf_a
(
layer
.
w2_weight_scale
.
data
[
i
]
.
view
(
torch
.
uint8
)
.
reshape
(
layer
.
hidden_size
,
-
1
),
epilogue_tile_m
,
)
w13_scale_shuffled
.
append
(
w13_sf_shuffled_i
.
contiguous
().
view
(
MXFP8_SCALE_DTYPE
)
)
w2_scale_shuffled
.
append
(
w2_sf_shuffled_i
.
contiguous
().
view
(
MXFP8_SCALE_DTYPE
)
)
replace_parameter
(
layer
,
"w13_weight"
,
torch
.
stack
(
w13_weight_shuffled
).
contiguous
()
)
replace_parameter
(
layer
,
"w2_weight"
,
torch
.
stack
(
w2_weight_shuffled
).
contiguous
()
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
torch
.
stack
(
w13_scale_shuffled
).
contiguous
(),
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
torch
.
stack
(
w2_scale_shuffled
).
contiguous
(),
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
self
.
_check_weight_dtypes
(
layer
)
self
.
_shuffle_weights_for_trtllm
(
layer
)
layer
.
_already_called_process_weights_after_loading
=
True
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalizeModular
|
None
:
raise
ValueError
(
f
"
{
self
.
__class__
.
__name__
}
uses the new modular kernel initialization "
"logic. This function should not be called."
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalizeModular
,
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEExpertsModular
:
raise
ValueError
(
f
"
{
self
.
__class__
.
__name__
}
uses the new modular kernel initialization "
"logic. This function should not be called."
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
# TRTLLM MXFP8 path is monolithic and does not use modular kernel config.
return
None
@
property
def
is_monolithic
(
self
)
->
bool
:
return
self
.
mxfp8_backend
==
MxFp8MoeBackend
.
FLASHINFER_TRTLLM
def
apply_monolithic
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
flashinfer.fused_moe.core
import
(
ActivationType
,
Fp8QuantizationType
,
)
assert
self
.
mxfp8_backend
==
MxFp8MoeBackend
.
FLASHINFER_TRTLLM
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB is not supported for FlashInfer TRTLLM MXFP8 MoE backend."
)
supported_activations
=
[
MoEActivation
.
SILU
]
if
layer
.
activation
not
in
supported_activations
:
raise
NotImplementedError
(
"FlashInfer TRTLLM MXFP8 MoE supports only "
f
"
{
supported_activations
}
, got
{
layer
.
activation
}
."
)
# Map vLLM MoEActivation to FlashInfer ActivationType.
activation_map
=
{
MoEActivation
.
SILU
:
ActivationType
.
Swiglu
,
MoEActivation
.
RELU2_NO_MUL
:
ActivationType
.
Relu2
,
}
fi_activation_type
:
ActivationType
=
activation_map
[
layer
.
activation
]
# DeepSeekV3 routing requires float32 logits; others expect bfloat16.
if
layer
.
routing_method_type
==
RoutingMethodType
.
DeepSeekV3
:
assert
router_logits
.
dtype
==
torch
.
float32
,
(
"DeepSeekV3 routing requires float32 router_logits, "
f
"got
{
router_logits
.
dtype
}
."
)
else
:
router_logits
=
router_logits
.
to
(
torch
.
bfloat16
)
# Treat 0 as "unset" for compatibility with ungrouped routing configs.
n_group
=
layer
.
num_expert_group
or
None
topk_group
=
layer
.
topk_group
or
None
hidden_states_mxfp8
,
hidden_states_scale
=
mxfp8_e4m3_quantize
(
x
,
is_sf_swizzled_layout
=
False
,
)
kwargs
:
dict
=
dict
(
routing_logits
=
router_logits
,
routing_bias
=
layer
.
e_score_correction_bias
,
hidden_states
=
hidden_states_mxfp8
,
hidden_states_scale
=
hidden_states_scale
,
gemm1_weights
=
layer
.
w13_weight
,
gemm1_weights_scale
=
layer
.
w13_weight_scale
,
gemm2_weights
=
layer
.
w2_weight
,
gemm2_weights_scale
=
layer
.
w2_weight_scale
,
num_experts
=
layer
.
global_num_experts
,
top_k
=
layer
.
top_k
,
# Keep Optional semantics: FlashInfer expects None for non-grouped
# routing (e.g. Qwen3 Renormalize), not 0.
n_group
=
n_group
,
topk_group
=
topk_group
,
intermediate_size
=
layer
.
intermediate_size_per_partition
,
local_expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
routed_scaling_factor
=
layer
.
routed_scaling_factor
,
routing_method_type
=
layer
.
routing_method_type
,
use_shuffled_weight
=
True
,
weight_layout
=
0
,
fp8_quantization_type
=
Fp8QuantizationType
.
MxFp8
,
)
if
fi_activation_type
!=
ActivationType
.
Swiglu
:
raise
NotImplementedError
(
"FlashInfer TRTLLM MXFP8 MoE supports only Swiglu activation, "
f
"got
{
fi_activation_type
}
."
)
return
flashinfer_trtllm_fp8_block_scale_moe
(
**
kwargs
)
def
apply
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
not
self
.
is_monolithic
raise
NotImplementedError
(
"Non-monolithic MXFP8 MoE path is not yet implemented."
)
# Register the method classes for ModelOptMxFp8Config
# Register the method classes for ModelOptMxFp8Config
ModelOptMxFp8Config
.
LinearMethodCls
=
ModelOptMxFp8LinearMethod
ModelOptMxFp8Config
.
LinearMethodCls
=
ModelOptMxFp8LinearMethod
ModelOptMxFp8Config
.
FusedMoEMethodCls
=
ModelOptMxFp8FusedMoE
ModelOptMxFp8Config
.
KVCacheMethodCls
=
ModelOptFp8KVCacheMethod
ModelOptMxFp8Config
.
KVCacheMethodCls
=
ModelOptFp8KVCacheMethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment