Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc917cce
Unverified
Commit
dc917cce
authored
Jan 22, 2026
by
bnellnm
Committed by
GitHub
Jan 22, 2026
Browse files
[MoE Refactor] Move `select_experts` from `FusedMoEQuantMethod` -> `FusedMoE` (#31996)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
fc56f4a0
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
114 additions
and
103 deletions
+114
-103
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+108
-80
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+6
-23
No files found.
vllm/model_executor/layers/quantization/mxfp4.py
View file @
dc917cce
...
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE
,
FusedMoE
,
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEMethodBase
,
FusedMoERouter
,
)
)
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
...
@@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def
allow_inplace
(
self
)
->
bool
:
def
allow_inplace
(
self
)
->
bool
:
return
True
return
True
@
property
def
is_monolithic
(
self
)
->
bool
:
return
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
TRITON
)
def
apply
(
def
apply
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
not
self
.
is_monolithic
if
layer
.
enable_eplb
:
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
MARLIN
:
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
MARLIN
:
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
return
fused_marlin_moe
(
return
fused_marlin_moe
(
x
,
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
@@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer
.
w2_bias
,
layer
.
w2_bias
,
layer
.
w13_weight_scale
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
global_scale1
=
None
,
global_scale1
=
None
,
...
@@ -942,62 +944,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -942,62 +944,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer
.
eplb_state
.
logical_replica_count
,
layer
.
eplb_state
.
logical_replica_count
,
),
"MXFP4 are not supported with this configuration."
),
"MXFP4 are not supported with this configuration."
if
(
assert
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
):
from
flashinfer
import
trtllm_fp4_block_scale_moe
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
:
assert
x
.
dtype
==
torch
.
bfloat16
x_quant
=
x
x_scale
=
None
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
:
from
flashinfer
import
mxfp8_quantize
x_quant
,
x_scale
=
mxfp8_quantize
(
x
,
False
)
# to mxfp8
x_scale
=
x_scale
.
view
(
torch
.
float8_e4m3fn
).
reshape
(
*
x
.
shape
[:
-
1
],
-
1
)
trtllm_gen_output
=
trtllm_fp4_block_scale_moe
(
router_logits
.
to
(
torch
.
bfloat16
),
None
,
# routing_bias
x_quant
,
x_scale
,
layer
.
w13_weight
,
# uint8 (e2m1 x 2)
layer
.
w13_weight_scale
,
# uint8 (e4m3 x 2)
layer
.
w13_bias
,
# fp32 per expert per channel
layer
.
gemm1_alpha
,
# fp32 per expert
layer
.
gemm1_beta
,
# fp32 per expert
layer
.
gemm1_clamp_limit
,
# fp32 per expert
layer
.
w2_weight
,
# uint8 (e2m1 x 2)
layer
.
w2_weight_scale
,
# ue8m0
layer
.
w2_bias
,
# fp32 per expert per channel
None
,
# output1_scale_scalar
None
,
# output1_scale_gate_scalar
None
,
# output2_scale_scalar
layer
.
global_num_experts
,
layer
.
top_k
,
None
,
# n_group
None
,
# topk_group
self
.
intermediate_size
,
# padded to multiple of 256
layer
.
ep_rank
*
layer
.
local_num_experts
,
# local_expert_offset
self
.
num_experts
,
# local num experts
None
,
# routed_scaling_factor
1
if
layer
.
renormalize
else
0
,
# routing_method_type, renormalize
True
,
# do finalize
tune_max_num_tokens
=
max
(
self
.
max_capture_size
,
1
),
)[
0
]
return
trtllm_gen_output
elif
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM90_FI_MXFP4_BF16
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM90_FI_MXFP4_BF16
):
from
vllm.utils.flashinfer
import
flashinfer_cutlass_fused_moe
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
)
from
vllm.utils.flashinfer
import
flashinfer_cutlass_fused_moe
# Backend-specific preparation
# Backend-specific preparation
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
:
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
:
...
@@ -1036,7 +987,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -1036,7 +987,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
)
output
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
bfloat16
)
output
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
bfloat16
)
_
=
flashinfer_cutlass_fused_moe
(
flashinfer_cutlass_fused_moe
(
input
=
fi_input
,
input
=
fi_input
,
token_selected_experts
=
topk_ids
.
to
(
torch
.
int
).
contiguous
(),
token_selected_experts
=
topk_ids
.
to
(
torch
.
int
).
contiguous
(),
token_final_scales
=
topk_weights
,
token_final_scales
=
topk_weights
,
...
@@ -1057,6 +1009,79 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -1057,6 +1009,79 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
)
return
output
return
output
def
apply_monolithic
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
is_monolithic
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
assert
_can_support_mxfp4
(
layer
.
use_grouped_topk
,
layer
.
topk_group
,
layer
.
num_expert_group
,
layer
.
expert_map
,
layer
.
custom_routing_function
,
layer
.
e_score_correction_bias
,
layer
.
apply_router_weight_on_input
,
layer
.
scoring_func
,
layer
.
activation
,
layer
.
eplb_state
.
expert_load_view
,
layer
.
eplb_state
.
logical_to_physical_map
,
layer
.
eplb_state
.
logical_replica_count
,
),
"MXFP4 are not supported with this configuration."
if
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
):
from
flashinfer
import
trtllm_fp4_block_scale_moe
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
:
assert
x
.
dtype
==
torch
.
bfloat16
x_quant
=
x
x_scale
=
None
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
:
from
flashinfer
import
mxfp8_quantize
x_quant
,
x_scale
=
mxfp8_quantize
(
x
,
False
)
# to mxfp8
x_scale
=
x_scale
.
view
(
torch
.
float8_e4m3fn
).
reshape
(
*
x
.
shape
[:
-
1
],
-
1
)
trtllm_gen_output
=
trtllm_fp4_block_scale_moe
(
router_logits
.
to
(
torch
.
bfloat16
),
None
,
# routing_bias
x_quant
,
x_scale
,
layer
.
w13_weight
,
# uint8 (e2m1 x 2)
layer
.
w13_weight_scale
,
# uint8 (e4m3 x 2)
layer
.
w13_bias
,
# fp32 per expert per channel
layer
.
gemm1_alpha
,
# fp32 per expert
layer
.
gemm1_beta
,
# fp32 per expert
layer
.
gemm1_clamp_limit
,
# fp32 per expert
layer
.
w2_weight
,
# uint8 (e2m1 x 2)
layer
.
w2_weight_scale
,
# ue8m0
layer
.
w2_bias
,
# fp32 per expert per channel
None
,
# output1_scale_scalar
None
,
# output1_scale_gate_scalar
None
,
# output2_scale_scalar
layer
.
global_num_experts
,
layer
.
top_k
,
None
,
# n_group
None
,
# topk_group
self
.
intermediate_size
,
# padded to multiple of 256
layer
.
ep_rank
*
layer
.
local_num_experts
,
# local_expert_offset
self
.
num_experts
,
# local num experts
None
,
# routed_scaling_factor
1
if
layer
.
renormalize
else
0
,
# routing_method_type, renormalize
True
,
# do finalize
tune_max_num_tokens
=
max
(
self
.
max_capture_size
,
1
),
)[
0
]
return
trtllm_gen_output
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
TRITON
:
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
TRITON
:
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
# noqa: E501
triton_kernel_moe_forward
,
triton_kernel_moe_forward
,
...
@@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
...
@@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
experts_start_id
=
ep_rank_start
,
experts_start_id
=
ep_rank_start
,
)
)
def
apply
(
@
property
def
is_monolithic
(
self
)
->
bool
:
return
True
def
apply_monolithic
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
dc917cce
...
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE
,
FusedMoE
,
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEMethodBase
,
FusedMoERouter
,
FusedMoeWeightScaleSupported
,
FusedMoeWeightScaleSupported
,
)
)
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
...
@@ -351,15 +350,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -351,15 +350,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def
apply
(
def
apply
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
if
self
.
rocm_aiter_moe_enabled
:
if
self
.
rocm_aiter_moe_enabled
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
rocm_aiter_fused_experts
,
...
@@ -388,7 +382,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -388,7 +382,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
None
,
None
,
layer
.
w13_weight_scale
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
quant_type_id
=
scalar_types
.
float8_e4m3fn
.
id
,
quant_type_id
=
scalar_types
.
float8_e4m3fn
.
id
,
...
@@ -544,15 +537,10 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -544,15 +537,10 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
def
apply
(
def
apply
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
rocm_aiter_fused_experts
,
)
)
...
@@ -753,15 +741,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
...
@@ -753,15 +741,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def
apply
(
def
apply
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
if
not
self
.
emulate
:
if
not
self
.
emulate
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
rocm_aiter_fused_experts
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment