Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d98f09b
Commit
7d98f09b
authored
Feb 03, 2026
by
Michael Goin
Committed by
Robert Shaw
Feb 03, 2026
Browse files
cherry pick
Signed-off-by:
Robert Shaw
<
rshaw@neuralmagic.com
>
parent
daa2784b
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
42 additions
and
32 deletions
+42
-32
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
.../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+31
-4
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+4
-4
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+3
-12
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+3
-12
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+1
-0
No files found.
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
View file @
7d98f09b
...
@@ -86,7 +86,23 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
...
@@ -86,7 +86,23 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
return
not
moe_parallel_config
.
enable_eplb
return
not
moe_parallel_config
.
enable_eplb
def
is_supported_config_trtllm
(
def _supports_router_logits_dtype(
    router_logits_dtype: torch.dtype | None,
    routing_method: RoutingMethodType,
) -> bool:
    """
    Report whether a router-logits dtype is usable with a routing method.

    The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by
    default. Only DeepSeekV3 routing supports float32 router_logits (which
    is converted internally in the kernel).

    Args:
        router_logits_dtype: dtype of the router logits, or None.
        routing_method: routing method the MoE layer is configured with.

    Returns:
        True for every dtype other than float32 (None included). For
        float32, True only when ``routing_method`` is
        ``RoutingMethodType.DeepSeekV3``.
    """
    logits_are_fp32 = router_logits_dtype == torch.float32
    if not logits_are_fp32:
        # Nothing else is restricted by this check.
        return True

    # Only DeepSeekV3 routing handles float32 logits
    # https://github.com/flashinfer-ai/flashinfer/issues/2469
    return routing_method == RoutingMethodType.DeepSeekV3
def
is_supported_config_trtllm_fp8
(
moe_config
:
FusedMoEConfig
,
moe_config
:
FusedMoEConfig
,
weight_key
:
QuantKey
|
None
,
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
...
@@ -115,13 +131,17 @@ def is_supported_config_trtllm(
...
@@ -115,13 +131,17 @@ def is_supported_config_trtllm(
return
False
,
_make_reason
(
"routing method"
)
return
False
,
_make_reason
(
"routing method"
)
elif
activation_format
!=
mk
.
FusedMoEActivationFormat
.
Standard
:
elif
activation_format
!=
mk
.
FusedMoEActivationFormat
.
Standard
:
return
False
,
_make_reason
(
"activation format"
)
return
False
,
_make_reason
(
"activation format"
)
elif
not
_supports_router_logits_dtype
(
moe_config
.
router_logits_dtype
,
moe_config
.
routing_method
):
return
False
,
_make_reason
(
"float32 router_logits with non-DeepSeekV3 routing"
)
return
True
,
None
return
True
,
None
def
flashinfer_fused_moe_blockscale_fp8
(
def
flashinfer_fused_moe_blockscale_fp8
(
routing_logits
:
torch
.
Tensor
,
routing_logits
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
|
None
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w13_weight
:
torch
.
Tensor
,
w13_weight
:
torch
.
Tensor
,
w13_weight_scale_inv
:
torch
.
Tensor
,
w13_weight_scale_inv
:
torch
.
Tensor
,
...
@@ -135,7 +155,7 @@ def flashinfer_fused_moe_blockscale_fp8(
...
@@ -135,7 +155,7 @@ def flashinfer_fused_moe_blockscale_fp8(
expert_offset
:
int
,
expert_offset
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
block_shape
:
list
[
int
],
block_shape
:
list
[
int
],
routing_method_type
:
int
=
int
(
RoutingMethodType
.
DeepSeekV3
)
,
routing_method_type
:
int
,
routed_scaling
:
float
|
None
=
1.0
,
routed_scaling
:
float
|
None
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.utils.flashinfer
import
flashinfer_trtllm_fp8_block_scale_moe
from
vllm.utils.flashinfer
import
flashinfer_trtllm_fp8_block_scale_moe
...
@@ -148,6 +168,13 @@ def flashinfer_fused_moe_blockscale_fp8(
...
@@ -148,6 +168,13 @@ def flashinfer_fused_moe_blockscale_fp8(
# Routing kernel expects #experts <= #threads 512
# Routing kernel expects #experts <= #threads 512
assert
global_num_experts
<=
512
assert
global_num_experts
<=
512
# The DeepSeekV3 routing method requires float32 router logits.
if
routing_method_type
==
RoutingMethodType
.
DeepSeekV3
:
routing_logits
=
routing_logits
.
to
(
torch
.
float32
)
if
routing_bias
is
not
None
:
routing_bias
=
routing_bias
.
to
(
x
.
dtype
)
a_q
,
a_sf
=
per_token_group_quant_fp8
(
x
,
block_shape
[
1
])
a_q
,
a_sf
=
per_token_group_quant_fp8
(
x
,
block_shape
[
1
])
# NOTE: scales of hidden states have to be transposed!
# NOTE: scales of hidden states have to be transposed!
a_sf_t
=
a_sf
.
t
().
contiguous
()
a_sf_t
=
a_sf
.
t
().
contiguous
()
...
@@ -175,7 +202,7 @@ def flashinfer_fused_moe_blockscale_fp8(
...
@@ -175,7 +202,7 @@ def flashinfer_fused_moe_blockscale_fp8(
def
flashinfer_fused_moe_blockscale_fp8_fake
(
def
flashinfer_fused_moe_blockscale_fp8_fake
(
routing_logits
:
torch
.
Tensor
,
routing_logits
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
|
None
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w13_weight
:
torch
.
Tensor
,
w13_weight
:
torch
.
Tensor
,
w13_weight_scale_inv
:
torch
.
Tensor
,
w13_weight_scale_inv
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
7d98f09b
...
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a16_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
)
)
from
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
import
(
from
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
import
(
is_supported_config_trtllm
,
is_supported_config_trtllm
_fp8
,
)
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
MoEPrepareAndFinalizeNoEP
,
...
@@ -212,7 +212,7 @@ def select_fp8_moe_backend(
...
@@ -212,7 +212,7 @@ def select_fp8_moe_backend(
if
fi_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
if
fi_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
backend
=
Fp8MoeBackend
.
FLASHINFER_TRTLLM
backend
=
Fp8MoeBackend
.
FLASHINFER_TRTLLM
supported
,
reason
=
is_supported_config_trtllm
(
supported
,
reason
=
is_supported_config_trtllm
_fp8
(
config
,
weight_key
,
activation_key
,
activation_format
config
,
weight_key
,
activation_key
,
activation_format
)
)
if
supported
:
if
supported
:
...
@@ -239,7 +239,7 @@ def select_fp8_moe_backend(
...
@@ -239,7 +239,7 @@ def select_fp8_moe_backend(
]:
]:
if
backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
if
backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
k_cls
=
None
k_cls
=
None
supported
,
reason
=
is_supported_config_trtllm
(
supported
,
reason
=
is_supported_config_trtllm
_fp8
(
config
,
config
,
weight_key
,
weight_key
,
activation_key
,
activation_key
,
...
@@ -308,7 +308,7 @@ def select_fp8_moe_backend(
...
@@ -308,7 +308,7 @@ def select_fp8_moe_backend(
for
backend
in
AVAILABLE_BACKENDS
:
for
backend
in
AVAILABLE_BACKENDS
:
if
backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
if
backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
k_cls
=
None
k_cls
=
None
supported
,
reason
=
is_supported_config_trtllm
(
supported
,
reason
=
is_supported_config_trtllm
_fp8
(
config
,
config
,
weight_key
,
weight_key
,
activation_key
,
activation_key
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
7d98f09b
...
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import (
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
int4_w4a16_moe_quant_config
,
int4_w4a16_moe_quant_config
,
int4_w4afp8_moe_quant_config
,
int4_w4afp8_moe_quant_config
,
int8_w8a8_moe_quant_config
,
int8_w8a8_moe_quant_config
,
...
@@ -1072,17 +1071,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1072,17 +1071,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if
self
.
block_quant
:
if
self
.
block_quant
:
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
e_score_correction_bias
=
(
layer
.
e_score_correction_bias
.
to
(
x
.
dtype
)
if
layer
.
e_score_correction_bias
is
not
None
else
None
)
routing_method_type
=
layer
.
routing_method_type
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_blockscale_fp8
(
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_blockscale_fp8
(
routing_logits
=
router_logits
.
to
(
torch
.
float32
)
routing_logits
=
router_logits
,
if
routing_method_type
==
RoutingMethodType
.
DeepSeekV3
routing_bias
=
layer
.
e_score_correction_bias
,
else
router_logits
,
routing_bias
=
e_score_correction_bias
,
x
=
x
,
x
=
x
,
w13_weight
=
layer
.
w13_weight
,
w13_weight
=
layer
.
w13_weight
,
w13_weight_scale_inv
=
layer
.
w13_weight_scale
,
w13_weight_scale_inv
=
layer
.
w13_weight_scale
,
...
@@ -1096,7 +1087,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1096,7 +1087,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
block_shape
=
self
.
weight_block_size
,
block_shape
=
self
.
weight_block_size
,
routing_method_type
=
routing_method_type
,
routing_method_type
=
layer
.
routing_method_type
,
routed_scaling
=
layer
.
routed_scaling_factor
,
routed_scaling
=
layer
.
routed_scaling_factor
,
)
)
else
:
else
:
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
7d98f09b
...
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import (
)
)
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
)
)
from
vllm.model_executor.layers.fused_moe.layer
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.layer
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
...
@@ -990,17 +989,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -990,17 +989,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
block_quant
:
if
self
.
block_quant
:
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
e_score_correction_bias
=
(
layer
.
e_score_correction_bias
.
to
(
x
.
dtype
)
if
layer
.
e_score_correction_bias
is
not
None
else
None
)
routing_method_type
=
layer
.
routing_method_type
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_blockscale_fp8
(
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_blockscale_fp8
(
routing_logits
=
router_logits
.
to
(
torch
.
float32
)
routing_logits
=
router_logits
,
if
routing_method_type
==
RoutingMethodType
.
DeepSeekV3
routing_bias
=
layer
.
e_score_correction_bias
,
else
router_logits
,
routing_bias
=
e_score_correction_bias
,
x
=
x
,
x
=
x
,
w13_weight
=
layer
.
w13_weight
,
w13_weight
=
layer
.
w13_weight
,
w13_weight_scale_inv
=
layer
.
w13_weight_scale_inv
,
w13_weight_scale_inv
=
layer
.
w13_weight_scale_inv
,
...
@@ -1014,7 +1005,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1014,7 +1005,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
block_shape
=
self
.
weight_block_size
,
block_shape
=
self
.
weight_block_size
,
routing_method_type
=
routing_method_type
,
routing_method_type
=
layer
.
routing_method_type
,
routed_scaling
=
layer
.
routed_scaling_factor
,
routed_scaling
=
layer
.
routed_scaling_factor
,
)
)
else
:
else
:
...
...
vllm/model_executor/models/minimax_m2.py
View file @
7d98f09b
...
@@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module):
...
@@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module):
renormalize
=
True
,
renormalize
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
,
router_logits_dtype
=
torch
.
float32
,
)
)
self
.
gate
=
ReplicatedLinear
(
self
.
gate
=
ReplicatedLinear
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment