Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e27b8ba3
Unverified
Commit
e27b8ba3
authored
Mar 19, 2026
by
Wei Zhao
Committed by
GitHub
Mar 19, 2026
Browse files
[Bug] Fix fp8 trtllm MoE modular kernel supported routing methods (#37346)
Signed-off-by:
wzhao18
<
wzhao18.sz@gmail.com
>
parent
40b8363b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
59 deletions
+16
-59
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
...model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
+16
-59
No files found.
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
View file @
e27b8ba3
...
...
@@ -69,54 +69,11 @@ class TrtLlmFp8ExpertsBase:
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
return
True
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
SUPPORTED_W_A
=
[
(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
),
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
),
(
kMxfp8Static
,
kMxfp8Dynamic
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_activation
(
activation
:
MoEActivation
)
->
bool
:
"""Supports only SiLU and RELU^2 non-gated activation."""
return
activation
in
[
MoEActivation
.
SILU
,
MoEActivation
.
RELU2_NO_MUL
]
@
staticmethod
def
_supports_routing_method
(
routing_method
:
RoutingMethodType
,
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if
(
weight_key
,
activation_key
)
in
[
(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
),
(
kMxfp8Static
,
kMxfp8Dynamic
),
]:
# NOTE(rob): potentially allow others here. This is a conservative list.
return
routing_method
in
[
RoutingMethodType
.
DeepSeekV3
,
RoutingMethodType
.
Renormalize
,
RoutingMethodType
.
RenormalizeNaive
,
]
elif
(
weight_key
,
activation_key
)
==
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
):
# NOTE(dbari): as above, potentially allow others here.
return
routing_method
in
[
RoutingMethodType
.
DeepSeekV3
,
RoutingMethodType
.
Llama4
,
RoutingMethodType
.
Renormalize
,
RoutingMethodType
.
RenormalizeNaive
,
]
else
:
raise
ValueError
(
"Unsupported quantization scheme."
)
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
"""Monolithic kernel so only use with naive DP/EP and TP."""
...
...
@@ -125,22 +82,6 @@ class TrtLlmFp8ExpertsBase:
or
moe_parallel_config
.
use_naive_all2all_kernels
)
and
not
moe_parallel_config
.
enable_eplb
@
staticmethod
def
_supports_router_logits_dtype
(
router_logits_dtype
:
torch
.
dtype
|
None
,
routing_method
:
RoutingMethodType
,
)
->
bool
:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if
router_logits_dtype
==
torch
.
float32
:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return
routing_method
==
RoutingMethodType
.
DeepSeekV3
return
True
def
supports_chunking
(
self
)
->
bool
:
return
False
...
...
@@ -306,6 +247,22 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_router_logits_dtype
(
router_logits_dtype
:
torch
.
dtype
|
None
,
routing_method
:
RoutingMethodType
,
)
->
bool
:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if
router_logits_dtype
==
torch
.
float32
:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return
routing_method
==
RoutingMethodType
.
DeepSeekV3
return
True
@
staticmethod
def
_supports_routing_method
(
routing_method
:
RoutingMethodType
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment