Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc917cce
Unverified
Commit
dc917cce
authored
Jan 22, 2026
by
bnellnm
Committed by
GitHub
Jan 22, 2026
Browse files
[MoE Refactor] Move `select_experts` from `FusedMoEQuantMethod` -> `FusedMoE` (#31996)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
fc56f4a0
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
114 additions
and
103 deletions
+114
-103
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+108
-80
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+6
-23
No files found.
vllm/model_executor/layers/quantization/mxfp4.py
View file @
dc917cce
...
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE
,
FusedMoE
,
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEMethodBase
,
FusedMoERouter
,
)
)
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
...
@@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def
allow_inplace
(
self
)
->
bool
:
def
allow_inplace
(
self
)
->
bool
:
return
True
return
True
@
property
def
is_monolithic
(
self
)
->
bool
:
return
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
TRITON
)
def
apply
(
def
apply
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
not
self
.
is_monolithic
if
layer
.
enable_eplb
:
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
MARLIN
:
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
MARLIN
:
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
return
fused_marlin_moe
(
return
fused_marlin_moe
(
x
,
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
@@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer
.
w2_bias
,
layer
.
w2_bias
,
layer
.
w13_weight_scale
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
global_scale1
=
None
,
global_scale1
=
None
,
...
@@ -942,6 +944,98 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -942,6 +944,98 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer
.
eplb_state
.
logical_replica_count
,
layer
.
eplb_state
.
logical_replica_count
,
),
"MXFP4 are not supported with this configuration."
),
"MXFP4 are not supported with this configuration."
assert
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM90_FI_MXFP4_BF16
)
from
vllm.utils.flashinfer
import
flashinfer_cutlass_fused_moe
# Backend-specific preparation
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
:
from
flashinfer
import
mxfp8_quantize
x_quant
,
x_scale
=
mxfp8_quantize
(
x
,
True
,
32
)
fake_input_scale
=
torch
.
ones
(
self
.
num_experts
,
device
=
x
.
device
)
quant_scales
=
[
layer
.
w13_weight_scale
.
contiguous
().
view
(
torch
.
int32
),
fake_input_scale
,
layer
.
w2_weight_scale
.
contiguous
().
view
(
torch
.
int32
),
fake_input_scale
,
]
fi_input
=
x_quant
extra_kwargs
=
dict
(
use_mxfp8_act_scaling
=
True
,
input_sf
=
x_scale
,
fc1_expert_weights
=
layer
.
w13_weight
.
contiguous
().
view
(
torch
.
long
),
fc2_expert_weights
=
layer
.
w2_weight
.
contiguous
().
view
(
torch
.
long
),
)
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM90_FI_MXFP4_BF16
:
assert
x
.
dtype
==
torch
.
bfloat16
quant_scales
=
[
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
]
fi_input
=
x
extra_kwargs
=
dict
(
use_w4_group_scaling
=
True
,
fc1_expert_weights
=
layer
.
w13_weight
,
fc2_expert_weights
=
layer
.
w2_weight
,
)
output
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
bfloat16
)
flashinfer_cutlass_fused_moe
(
input
=
fi_input
,
token_selected_experts
=
topk_ids
.
to
(
torch
.
int
).
contiguous
(),
token_final_scales
=
topk_weights
,
output_dtype
=
torch
.
bfloat16
,
output
=
output
,
quant_scales
=
quant_scales
,
fc1_expert_biases
=
layer
.
w13_bias
,
fc2_expert_biases
=
layer
.
w2_bias
,
swiglu_alpha
=
layer
.
gemm1_alpha
,
swiglu_beta
=
layer
.
gemm1_beta
,
swiglu_limit
=
layer
.
gemm1_clamp_limit
,
tp_size
=
self
.
moe
.
tp_size
,
tp_rank
=
self
.
moe
.
tp_rank
,
ep_size
=
self
.
moe
.
ep_size
,
ep_rank
=
self
.
moe
.
ep_rank
,
tune_max_num_tokens
=
max
(
self
.
max_capture_size
,
1
),
**
extra_kwargs
,
)
return
output
def
apply_monolithic
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
is_monolithic
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
assert
_can_support_mxfp4
(
layer
.
use_grouped_topk
,
layer
.
topk_group
,
layer
.
num_expert_group
,
layer
.
expert_map
,
layer
.
custom_routing_function
,
layer
.
e_score_correction_bias
,
layer
.
apply_router_weight_on_input
,
layer
.
scoring_func
,
layer
.
activation
,
layer
.
eplb_state
.
expert_load_view
,
layer
.
eplb_state
.
logical_to_physical_map
,
layer
.
eplb_state
.
logical_replica_count
,
),
"MXFP4 are not supported with this configuration."
if
(
if
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
...
@@ -988,75 +1082,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -988,75 +1082,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens
=
max
(
self
.
max_capture_size
,
1
),
tune_max_num_tokens
=
max
(
self
.
max_capture_size
,
1
),
)[
0
]
)[
0
]
return
trtllm_gen_output
return
trtllm_gen_output
elif
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
or
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM90_FI_MXFP4_BF16
):
from
vllm.utils.flashinfer
import
flashinfer_cutlass_fused_moe
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
# Backend-specific preparation
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
:
from
flashinfer
import
mxfp8_quantize
x_quant
,
x_scale
=
mxfp8_quantize
(
x
,
True
,
32
)
fake_input_scale
=
torch
.
ones
(
self
.
num_experts
,
device
=
x
.
device
)
quant_scales
=
[
layer
.
w13_weight_scale
.
contiguous
().
view
(
torch
.
int32
),
fake_input_scale
,
layer
.
w2_weight_scale
.
contiguous
().
view
(
torch
.
int32
),
fake_input_scale
,
]
fi_input
=
x_quant
extra_kwargs
=
dict
(
use_mxfp8_act_scaling
=
True
,
input_sf
=
x_scale
,
fc1_expert_weights
=
layer
.
w13_weight
.
contiguous
().
view
(
torch
.
long
),
fc2_expert_weights
=
layer
.
w2_weight
.
contiguous
().
view
(
torch
.
long
),
)
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM90_FI_MXFP4_BF16
:
assert
x
.
dtype
==
torch
.
bfloat16
quant_scales
=
[
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
]
fi_input
=
x
extra_kwargs
=
dict
(
use_w4_group_scaling
=
True
,
fc1_expert_weights
=
layer
.
w13_weight
,
fc2_expert_weights
=
layer
.
w2_weight
,
)
output
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
bfloat16
)
_
=
flashinfer_cutlass_fused_moe
(
input
=
fi_input
,
token_selected_experts
=
topk_ids
.
to
(
torch
.
int
).
contiguous
(),
token_final_scales
=
topk_weights
,
output_dtype
=
torch
.
bfloat16
,
output
=
output
,
quant_scales
=
quant_scales
,
fc1_expert_biases
=
layer
.
w13_bias
,
fc2_expert_biases
=
layer
.
w2_bias
,
swiglu_alpha
=
layer
.
gemm1_alpha
,
swiglu_beta
=
layer
.
gemm1_beta
,
swiglu_limit
=
layer
.
gemm1_clamp_limit
,
tp_size
=
self
.
moe
.
tp_size
,
tp_rank
=
self
.
moe
.
tp_rank
,
ep_size
=
self
.
moe
.
ep_size
,
ep_rank
=
self
.
moe
.
ep_rank
,
tune_max_num_tokens
=
max
(
self
.
max_capture_size
,
1
),
**
extra_kwargs
,
)
return
output
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
TRITON
:
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
TRITON
:
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
# noqa: E501
triton_kernel_moe_forward
,
triton_kernel_moe_forward
,
...
@@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
...
@@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
experts_start_id
=
ep_rank_start
,
experts_start_id
=
ep_rank_start
,
)
)
def
apply
(
@
property
def
is_monolithic
(
self
)
->
bool
:
return
True
def
apply_monolithic
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
dc917cce
...
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE
,
FusedMoE
,
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEMethodBase
,
FusedMoERouter
,
FusedMoeWeightScaleSupported
,
FusedMoeWeightScaleSupported
,
)
)
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
...
@@ -351,15 +350,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -351,15 +350,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def
apply
(
def
apply
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
if
self
.
rocm_aiter_moe_enabled
:
if
self
.
rocm_aiter_moe_enabled
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
rocm_aiter_fused_experts
,
...
@@ -388,7 +382,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -388,7 +382,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
None
,
None
,
layer
.
w13_weight_scale
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
quant_type_id
=
scalar_types
.
float8_e4m3fn
.
id
,
quant_type_id
=
scalar_types
.
float8_e4m3fn
.
id
,
...
@@ -544,15 +537,10 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -544,15 +537,10 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
def
apply
(
def
apply
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
rocm_aiter_fused_experts
,
)
)
...
@@ -753,15 +741,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
...
@@ -753,15 +741,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def
apply
(
def
apply
(
self
,
self
,
layer
:
FusedMoE
,
layer
:
FusedMoE
,
router
:
FusedMoERouter
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
if
not
self
.
emulate
:
if
not
self
.
emulate
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
rocm_aiter_fused_experts
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment