Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19ec9a0a
Unverified
Commit
19ec9a0a
authored
Apr 14, 2026
by
bnellnm
Committed by
GitHub
Apr 14, 2026
Browse files
[MoE Refactor] Refactor ZeroExpertFusedMoE into new framework (#35549)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
1a9353bb
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
461 additions
and
204 deletions
+461
-204
tests/kernels/moe/test_zero_expert_moe.py
tests/kernels/moe/test_zero_expert_moe.py
+282
-0
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+0
-4
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+3
-0
vllm/model_executor/layers/fused_moe/router/router_factory.py
.../model_executor/layers/fused_moe/router/router_factory.py
+38
-4
vllm/model_executor/layers/fused_moe/router/zero_expert_router.py
...el_executor/layers/fused_moe/router/zero_expert_router.py
+115
-0
vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
...model_executor/layers/fused_moe/runner/moe_runner_base.py
+19
-1
vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py
.../model_executor/layers/fused_moe/zero_expert_fused_moe.py
+0
-189
vllm/model_executor/models/longcat_flash.py
vllm/model_executor/models/longcat_flash.py
+4
-6
No files found.
tests/kernels/moe/test_zero_expert_moe.py
0 → 100644
View file @
19ec9a0a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for FusedMoE with zero experts.
Verifies that:
- The ZeroExpertRouter is properly created and used as the layer router.
- A forward pass through FusedMoE with zero experts produces correct output.
- The output decomposes correctly into real expert + zero expert contributions.
Note: tests generated with Claude.
"""
import
pytest
import
torch
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.router.zero_expert_router
import
(
ZeroExpertRouter
,
)
from
vllm.v1.worker.workspace
import
init_workspace_manager
@
pytest
.
fixture
def
zero_expert_moe
(
dist_init
,
default_vllm_config
):
"""Create a FusedMoE layer with zero experts."""
num_experts
=
4
top_k
=
2
# hidden_size must be >= 256 for the zero expert identity kernel to
# produce output (its BLOCK_SIZE=256 causes grid=0 when hidden_dim<256).
hidden_size
=
256
intermediate_size
=
512
zero_expert_num
=
1
e_score_correction_bias
=
torch
.
zeros
(
num_experts
+
zero_expert_num
,
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
vllm_config
=
VllmConfig
()
vllm_config
.
compilation_config
.
static_forward_context
=
dict
()
with
set_current_vllm_config
(
vllm_config
),
set_forward_context
(
None
,
vllm_config
):
init_workspace_manager
(
torch
.
accelerator
.
current_device_index
())
layer
=
FusedMoE
(
zero_expert_type
=
"identity"
,
e_score_correction_bias
=
e_score_correction_bias
,
num_experts
=
num_experts
,
top_k
=
top_k
,
hidden_size
=
hidden_size
,
intermediate_size
=
intermediate_size
,
params_dtype
=
torch
.
bfloat16
,
prefix
=
"test_zero_expert_moe"
,
renormalize
=
False
,
routed_scaling_factor
=
1.0
,
scoring_func
=
"softmax"
,
).
cuda
()
layer
.
quant_method
.
process_weights_after_loading
(
layer
)
yield
layer
,
vllm_config
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
32
])
def
test_zero_expert_moe_router_is_zero_expert_router
(
zero_expert_moe
,
num_tokens
):
"""Verify that FusedMoE with zero_expert_type creates a ZeroExpertRouter."""
layer
,
_
=
zero_expert_moe
assert
isinstance
(
layer
.
router
,
ZeroExpertRouter
),
(
f
"Expected ZeroExpertRouter but got
{
type
(
layer
.
router
).
__name__
}
."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
32
])
def
test_zero_expert_moe_no_custom_routing_fn
(
zero_expert_moe
,
num_tokens
):
"""Verify that custom_routing_function is not set (routing is handled
by ZeroExpertRouter, not a memoizing closure)."""
layer
,
_
=
zero_expert_moe
assert
layer
.
custom_routing_function
is
None
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
32
])
def
test_zero_expert_moe_forward
(
zero_expert_moe
,
num_tokens
):
"""Run a forward pass through FusedMoE with zero experts and verify output shape."""
layer
,
vllm_config
=
zero_expert_moe
hidden_size
=
layer
.
hidden_size
num_experts
=
4
zero_expert_num
=
1
total_experts
=
num_experts
+
zero_expert_num
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
router_logits
=
torch
.
randn
(
num_tokens
,
total_experts
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
# Initialize weights to small random values to avoid NaN from
# uninitialized memory.
with
torch
.
no_grad
():
for
param
in
layer
.
parameters
():
if
param
.
dtype
.
is_floating_point
:
param
.
normal_
(
0
,
0.01
)
with
set_current_vllm_config
(
vllm_config
),
set_forward_context
(
None
,
vllm_config
):
get_forward_context
().
all_moe_layers
=
None
output
=
layer
.
forward
(
hidden_states
,
router_logits
)
assert
output
.
shape
==
hidden_states
.
shape
,
(
f
"Expected output shape
{
hidden_states
.
shape
}
, got
{
output
.
shape
}
"
)
assert
output
.
dtype
==
hidden_states
.
dtype
assert
not
torch
.
isnan
(
output
).
any
(),
"Output contains NaN values"
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
32
])
def
test_zero_expert_moe_output_decomposition
(
zero_expert_moe
,
num_tokens
):
"""Validate that the FusedMoE output equals a plain FusedMoE
output (real experts only) plus the zero expert contribution.
The key invariant is:
zero_layer.forward(h, r_full) == plain_layer.forward(h, r_real)
+ zero_expert_output
We create a plain FusedMoE layer with the same weights and real-expert-only
router logits, compute the zero expert output via the ZeroExpertRouter, and
verify the sum matches the FusedMoE output.
"""
layer
,
vllm_config
=
zero_expert_moe
num_experts
=
4
zero_expert_num
=
1
total_experts
=
num_experts
+
zero_expert_num
hidden_states
=
torch
.
randn
(
num_tokens
,
layer
.
hidden_size
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
router_logits
=
torch
.
randn
(
num_tokens
,
total_experts
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
with
torch
.
no_grad
():
for
param
in
layer
.
parameters
():
if
param
.
dtype
.
is_floating_point
:
param
.
normal_
(
0
,
0.01
)
with
set_current_vllm_config
(
vllm_config
),
set_forward_context
(
None
,
vllm_config
):
get_forward_context
().
all_moe_layers
=
None
# Create a plain FusedMoE layer with the same config but no zero
# experts. Use a separate prefix to avoid collision.
plain_layer
=
FusedMoE
(
num_experts
=
num_experts
,
top_k
=
layer
.
top_k
,
hidden_size
=
layer
.
hidden_size
,
intermediate_size
=
layer
.
intermediate_size_per_partition
,
params_dtype
=
torch
.
bfloat16
,
prefix
=
"test_zero_expert_moe_plain"
,
renormalize
=
False
,
scoring_func
=
"softmax"
,
e_score_correction_bias
=
layer
.
e_score_correction_bias
,
).
cuda
()
# Share weights from the zero expert layer.
plain_layer
.
w13_weight
.
data
.
copy_
(
layer
.
w13_weight
.
data
)
plain_layer
.
w2_weight
.
data
.
copy_
(
layer
.
w2_weight
.
data
)
plain_layer
.
quant_method
.
process_weights_after_loading
(
plain_layer
)
# Compute routing via the ZeroExpertRouter. This produces masked
# topk_weights/topk_ids (zero expert entries have weight=0, id=0)
# and stores zero_expert_output as a side effect.
topk_weights
,
topk_ids
=
layer
.
router
.
select_experts
(
hidden_states
,
router_logits
)
zero_output
=
layer
.
router
.
zero_expert_output
# Compute real expert output using the plain layer with the masked
# routing from the ZeroExpertRouter.
real_output
=
plain_layer
.
quant_method
.
apply
(
layer
=
plain_layer
,
x
=
hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
shared_experts_input
=
None
,
)
# Get the combined output from the zero expert layer.
full_output
=
layer
.
forward
(
hidden_states
,
router_logits
)
assert
zero_output
is
not
None
,
"Zero expert output should not be None"
assert
not
torch
.
isnan
(
real_output
).
any
(),
"Real expert output has NaN"
assert
not
torch
.
isnan
(
zero_output
).
any
(),
"Zero expert output has NaN"
assert
not
torch
.
isnan
(
full_output
).
any
(),
"Full output has NaN"
expected
=
real_output
+
zero_output
torch
.
testing
.
assert_close
(
full_output
,
expected
,
atol
=
0
,
rtol
=
0
,
msg
=
"FusedMoE output should equal plain FusedMoE output "
"plus zero expert contribution"
,
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
32
])
def
test_zero_expert_moe_zero_expert_is_identity
(
zero_expert_moe
,
num_tokens
):
"""Validate zero expert identity behavior.
When routing strongly favors the zero expert, its contribution should
be a scaled version of hidden_states (identity operation). We verify
this by manually computing the expected zero expert output from the
routing weights and comparing against what the router produces.
"""
layer
,
vllm_config
=
zero_expert_moe
num_experts
=
4
zero_expert_num
=
1
total_experts
=
num_experts
+
zero_expert_num
hidden_states
=
torch
.
randn
(
num_tokens
,
layer
.
hidden_size
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
# Strongly bias toward the zero expert (index 4).
router_logits
=
torch
.
full
(
(
num_tokens
,
total_experts
),
-
10.0
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
router_logits
[:,
num_experts
]
=
10.0
# zero expert gets high logit
with
torch
.
no_grad
():
for
param
in
layer
.
parameters
():
if
param
.
dtype
.
is_floating_point
:
param
.
normal_
(
0
,
0.01
)
with
set_current_vllm_config
(
vllm_config
),
set_forward_context
(
None
,
vllm_config
):
get_forward_context
().
all_moe_layers
=
None
# Run routing to get topk_weights/topk_ids before masking.
from
vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router
import
(
fused_topk_bias
,
)
topk_weights
,
topk_ids
=
fused_topk_bias
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
e_score_correction_bias
=
layer
.
router
.
e_score_correction_bias
.
data
,
topk
=
layer
.
top_k
,
renormalize
=
layer
.
router
.
renormalize
,
scoring_func
=
layer
.
router
.
scoring_func
,
)
# Manually compute expected zero expert identity output:
# For each token, sum routing weights assigned to zero expert slots,
# then multiply by hidden_states.
zero_mask
=
topk_ids
>=
num_experts
zero_weight_per_token
=
(
topk_weights
*
zero_mask
.
float
()).
sum
(
dim
=-
1
,
keepdim
=
True
)
expected_zero_output
=
(
hidden_states
.
float
()
*
zero_weight_per_token
).
to
(
hidden_states
.
dtype
)
# Run routing directly to trigger zero expert computation
# without going through the runner (which consumes the output).
layer
.
router
.
select_experts
(
hidden_states
,
router_logits
)
actual_zero_output
=
layer
.
router
.
zero_expert_output
assert
actual_zero_output
is
not
None
assert
zero_mask
.
any
(),
(
"With high zero expert logit, at least some slots should route "
"to the zero expert"
)
torch
.
testing
.
assert_close
(
actual_zero_output
,
expected_zero_output
,
atol
=
1e-3
,
rtol
=
1e-3
,
msg
=
"Zero expert identity output should equal "
"hidden_states * sum(zero_expert_weights)"
,
)
vllm/model_executor/layers/fused_moe/__init__.py
View file @
19ec9a0a
...
@@ -33,9 +33,6 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
...
@@ -33,9 +33,6 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
UnquantizedFusedMoEMethod
,
UnquantizedFusedMoEMethod
,
)
)
from
vllm.model_executor.layers.fused_moe.zero_expert_fused_moe
import
(
ZeroExpertFusedMoE
,
)
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
_config
:
dict
[
str
,
Any
]
|
None
=
None
_config
:
dict
[
str
,
Any
]
|
None
=
None
...
@@ -68,7 +65,6 @@ __all__ = [
...
@@ -68,7 +65,6 @@ __all__ = [
"GateLinear"
,
"GateLinear"
,
"RoutingMethodType"
,
"RoutingMethodType"
,
"SharedFusedMoE"
,
"SharedFusedMoE"
,
"ZeroExpertFusedMoE"
,
"activation_without_mul"
,
"activation_without_mul"
,
"apply_moe_activation"
,
"apply_moe_activation"
,
"override_config"
,
"override_config"
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
19ec9a0a
...
@@ -274,6 +274,7 @@ class FusedMoE(PluggableLayer):
...
@@ -274,6 +274,7 @@ class FusedMoE(PluggableLayer):
gate
:
torch
.
nn
.
Module
|
None
=
None
,
gate
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
routed_input_transform
:
torch
.
nn
.
Module
|
None
=
None
,
routed_input_transform
:
torch
.
nn
.
Module
|
None
=
None
,
zero_expert_type
:
str
|
None
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -462,6 +463,8 @@ class FusedMoE(PluggableLayer):
...
@@ -462,6 +463,8 @@ class FusedMoE(PluggableLayer):
# TODO(bnell): once we can construct the MK at init time, we
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
# can make this a value.
indices_type_getter
=
lambda
:
self
.
quant_method
.
topk_indices_dtype
,
indices_type_getter
=
lambda
:
self
.
quant_method
.
topk_indices_dtype
,
zero_expert_type
=
zero_expert_type
,
num_logical_experts
=
self
.
logical_num_experts
,
)
)
self
.
routing_method_type
:
RoutingMethodType
=
self
.
router
.
routing_method_type
self
.
routing_method_type
:
RoutingMethodType
=
self
.
router
.
routing_method_type
...
...
vllm/model_executor/layers/fused_moe/router/router_factory.py
View file @
19ec9a0a
...
@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
...
@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
from
vllm.model_executor.layers.fused_moe.router.routing_simulator_router
import
(
from
vllm.model_executor.layers.fused_moe.router.routing_simulator_router
import
(
RoutingSimulatorRouter
,
RoutingSimulatorRouter
,
)
)
from
vllm.model_executor.layers.fused_moe.router.zero_expert_router
import
(
ZeroExpertRouter
,
)
EMPTY_EPLB_STATE
:
EplbLayerState
=
EplbLayerState
()
EMPTY_EPLB_STATE
:
EplbLayerState
=
EplbLayerState
()
...
@@ -49,6 +52,9 @@ def create_fused_moe_router(
...
@@ -49,6 +52,9 @@ def create_fused_moe_router(
# eplb parameters
# eplb parameters
enable_eplb
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
eplb_state
:
EplbLayerState
=
EMPTY_EPLB_STATE
,
eplb_state
:
EplbLayerState
=
EMPTY_EPLB_STATE
,
# zero expert parameters
zero_expert_type
:
str
|
None
=
None
,
num_logical_experts
:
int
|
None
=
None
,
)
->
FusedMoERouter
:
)
->
FusedMoERouter
:
"""
"""
Factory function to create the appropriate FusedMoERouter subclass based on
Factory function to create the appropriate FusedMoERouter subclass based on
...
@@ -56,10 +62,11 @@ def create_fused_moe_router(
...
@@ -56,10 +62,11 @@ def create_fused_moe_router(
The selection logic follows this priority order:
The selection logic follows this priority order:
1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
2. GroupedTopKRouter - if use_grouped_topk is True
2. ZeroExpertRouter - if zero_expert_type is not None
3. CustomRoutingRouter - if custom_routing_function is not None
3. GroupedTopKRouter - if use_grouped_topk is True
4. FusedTopKBiasRouter - if e_score_correction_bias is not None
4. CustomRoutingRouter - if custom_routing_function is not None
5. FusedTopKRouter - default fallback
5. FusedTopKBiasRouter - if e_score_correction_bias is not None
6. FusedTopKRouter - default fallback
Common arguments:
Common arguments:
top_k: Number of experts to select per token
top_k: Number of experts to select per token
...
@@ -86,6 +93,12 @@ def create_fused_moe_router(
...
@@ -86,6 +93,12 @@ def create_fused_moe_router(
enable_eplb: Whether EPLB is enabled
enable_eplb: Whether EPLB is enabled
eplb_state: EPLB (Expert Parallelism Load Balancing) state
eplb_state: EPLB (Expert Parallelism Load Balancing) state
Zero expert arguments:
zero_expert_type: Type of zero expert (e.g. identity). If not None,
creates a ZeroExpertRouter.
num_logical_experts: Number of real (non-zero) experts. Required when
zero_expert_type is not None.
Returns:
Returns:
An instance of the appropriate FusedMoERouter subclass
An instance of the appropriate FusedMoERouter subclass
"""
"""
...
@@ -100,6 +113,27 @@ def create_fused_moe_router(
...
@@ -100,6 +113,27 @@ def create_fused_moe_router(
indices_type_getter
=
indices_type_getter
,
indices_type_getter
=
indices_type_getter
,
)
)
if
zero_expert_type
is
not
None
:
assert
num_logical_experts
is
not
None
,
(
"num_logical_experts is required when zero_expert_type is set"
)
assert
e_score_correction_bias
is
not
None
,
(
"e_score_correction_bias is required when zero_expert_type is set"
)
return
ZeroExpertRouter
(
top_k
=
top_k
,
global_num_experts
=
global_num_experts
,
eplb_state
=
eplb_state
,
e_score_correction_bias
=
e_score_correction_bias
,
num_logical_experts
=
num_logical_experts
,
zero_expert_type
=
zero_expert_type
,
scoring_func
=
scoring_func
,
renormalize
=
renormalize
,
routed_scaling_factor
=
routed_scaling_factor
,
enable_eplb
=
enable_eplb
,
indices_type_getter
=
indices_type_getter
,
)
if
use_grouped_topk
:
if
use_grouped_topk
:
assert
custom_routing_function
is
None
assert
custom_routing_function
is
None
if
num_expert_group
is
None
or
topk_group
is
None
:
if
num_expert_group
is
None
or
topk_group
is
None
:
...
...
vllm/model_executor/layers/fused_moe/router/zero_expert_router.py
0 → 100644
View file @
19ec9a0a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.fused_moe.config
import
(
RoutingMethodType
,
get_routing_method_type
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
zero_experts_compute_triton
,
)
from
vllm.model_executor.layers.fused_moe.router.base_router
import
BaseRouter
from
vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router
import
(
fused_topk_bias
,
)
class
ZeroExpertRouter
(
BaseRouter
):
"""Router that handles zero expert computation as part of routing.
Routes over all experts (real + zero) using full e_score_correction_bias.
Computes zero expert identity contributions as a side effect during routing.
Remaps zero expert IDs to real expert ID 0 (with weight 0) so downstream
MoE computation can ignore them.
"""
def
__init__
(
self
,
top_k
:
int
,
global_num_experts
:
int
,
eplb_state
:
EplbLayerState
,
e_score_correction_bias
:
torch
.
Tensor
,
num_logical_experts
:
int
,
zero_expert_type
:
str
,
scoring_func
:
str
=
"softmax"
,
renormalize
:
bool
=
False
,
routed_scaling_factor
:
float
=
1.0
,
enable_eplb
:
bool
=
False
,
indices_type_getter
:
Callable
[[],
torch
.
dtype
|
None
]
|
None
=
None
,
):
super
().
__init__
(
top_k
=
top_k
,
global_num_experts
=
global_num_experts
,
eplb_state
=
eplb_state
,
enable_eplb
=
enable_eplb
,
indices_type_getter
=
indices_type_getter
,
)
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
num_logical_experts
=
num_logical_experts
self
.
zero_expert_type
=
zero_expert_type
self
.
scoring_func
=
scoring_func
self
.
renormalize
=
renormalize
self
.
routed_scaling_factor
=
routed_scaling_factor
self
.
_zero_expert_output
:
torch
.
Tensor
|
None
=
None
@
property
def
routing_method_type
(
self
)
->
RoutingMethodType
:
return
get_routing_method_type
(
scoring_func
=
self
.
scoring_func
,
top_k
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
num_expert_group
=
None
,
has_e_score_bias
=
True
,
)
def
_compute_routing
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
indices_type
:
torch
.
dtype
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute routing with full bias, compute zero expert output,
mask zero expert IDs."""
topk_weights
,
topk_ids
=
fused_topk_bias
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
e_score_correction_bias
=
self
.
e_score_correction_bias
.
data
,
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
scoring_func
=
self
.
scoring_func
,
indices_type
=
indices_type
,
)
if
self
.
routed_scaling_factor
!=
1.0
:
topk_weights
*=
self
.
routed_scaling_factor
# Compute zero expert output using pre-EPLB topk_ids/weights.
# zero_experts_compute_triton modifies its inputs in-place, so
# pass clones.
self
.
_zero_expert_output
=
zero_experts_compute_triton
(
expert_indices
=
topk_ids
.
clone
(),
expert_scales
=
topk_weights
.
clone
(),
num_experts
=
self
.
num_logical_experts
,
zero_expert_type
=
self
.
zero_expert_type
,
hidden_states
=
hidden_states
,
)
# Mask zero expert entries: remap zero expert IDs to 0 with weight 0
# so downstream MoE computation ignores them.
zero_mask
=
topk_ids
>=
self
.
num_logical_experts
topk_ids
[
zero_mask
]
=
0
topk_weights
[
zero_mask
]
=
0.0
return
topk_weights
,
topk_ids
@
property
def
zero_expert_output
(
self
)
->
torch
.
Tensor
|
None
:
"""Retrieve and clear the zero expert output."""
output
=
self
.
_zero_expert_output
self
.
_zero_expert_output
=
None
return
output
vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
View file @
19ec9a0a
...
@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
...
@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
from
vllm.model_executor.layers.fused_moe.router.fused_moe_router
import
(
from
vllm.model_executor.layers.fused_moe.router.fused_moe_router
import
(
FusedMoERouter
,
FusedMoERouter
,
)
)
from
vllm.model_executor.layers.fused_moe.router.zero_expert_router
import
(
ZeroExpertRouter
,
)
from
vllm.model_executor.layers.fused_moe.runner.moe_runner
import
MoERunner
from
vllm.model_executor.layers.fused_moe.runner.moe_runner
import
MoERunner
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
SharedExperts
,
...
@@ -443,6 +446,19 @@ class MoERunnerBase(MoERunner):
...
@@ -443,6 +446,19 @@ class MoERunnerBase(MoERunner):
if
self
.
_shared_experts
is
not
None
:
if
self
.
_shared_experts
is
not
None
:
self
.
_shared_experts
.
maybe_sync_shared_experts_stream
(
shared_experts_input
)
self
.
_shared_experts
.
maybe_sync_shared_experts_stream
(
shared_experts_input
)
def
_maybe_add_zero_expert_output
(
self
,
result
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
isinstance
(
self
.
router
,
ZeroExpertRouter
):
zero_expert_output
=
self
.
router
.
zero_expert_output
assert
zero_expert_output
is
not
None
if
isinstance
(
result
,
tuple
):
result
=
(
result
[
0
],
result
[
1
]
+
zero_expert_output
)
else
:
result
=
result
+
zero_expert_output
return
result
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -494,7 +510,9 @@ class MoERunnerBase(MoERunner):
...
@@ -494,7 +510,9 @@ class MoERunnerBase(MoERunner):
self
.
_encode_layer_name
(),
self
.
_encode_layer_name
(),
)
)
return
self
.
_maybe_reduce_output
(
fused_output
,
og_hidden_dims
)
result
=
self
.
_maybe_reduce_output
(
fused_output
,
og_hidden_dims
)
return
self
.
_maybe_add_zero_expert_output
(
result
)
def
forward_dispatch
(
def
forward_dispatch
(
self
,
self
,
...
...
vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py
deleted
100644 → 0
View file @
1a9353bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
contextmanager
import
torch
from
torch
import
nn
from
vllm.model_executor.layers.fused_moe.fused_moe
import
zero_experts_compute_triton
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
class
ZeroExpertFusedMoE
(
FusedMoE
):
"""
A FusedMoE operation that also computes the results of zero experts.
Zero experts perform identity operations (scaled pass-through) instead
of full MLP computations.
This class uses memoization to avoid redundant routing computation:
routing is computed once and reused for both zero expert computation
and the main FusedMoE forward pass.
"""
def
__init__
(
self
,
zero_expert_num
:
int
,
zero_expert_type
:
str
,
router
:
nn
.
Module
,
**
kwargs
,
):
# ZeroExpertFusedMoE manages its own custom_routing_function for memoization
assert
(
"custom_routing_function"
not
in
kwargs
or
kwargs
.
get
(
"custom_routing_function"
)
is
None
),
(
"ZeroExpertFusedMoE does not support external custom_routing_function. "
"It manages its own for routing memoization."
)
# Automatically slice router's e_score_correction_bias to only include
# real experts (not zero_experts) for the base FusedMoE.
# The full bias will be used temporarily in forward() for routing.
if
hasattr
(
router
,
"e_score_correction_bias"
)
and
"num_experts"
in
kwargs
:
num_real_experts
=
kwargs
[
"num_experts"
]
router_bias
=
router
.
e_score_correction_bias
user_bias
=
kwargs
.
get
(
"e_score_correction_bias"
)
# Use router's bias if:
# 1. User didn't provide bias, or
# 2. User provided full bias (same size as router)
if
user_bias
is
None
or
user_bias
.
shape
[
0
]
==
router_bias
.
shape
[
0
]:
kwargs
[
"e_score_correction_bias"
]
=
router_bias
[:
num_real_experts
]
# FusedMoE no longer accepts zero_expert_num/zero_expert_type.
# We handle zero experts ourselves in forward().
super
().
__init__
(
**
kwargs
)
# Store the actual zero_expert_num and zero_expert_type for our own use
self
.
_actual_zero_expert_num
=
zero_expert_num
self
.
_actual_zero_expert_type
=
zero_expert_type
self
.
_router
=
router
# Full router (includes zero experts)
# Expose zero_expert_num and zero_expert_type as attributes for
# compatibility with quantization methods that check these attributes
self
.
zero_expert_num
=
0
self
.
zero_expert_type
=
None
# Memoization state for routing results
self
.
_memoized_topk_weights
:
torch
.
Tensor
|
None
=
None
self
.
_memoized_topk_ids
:
torch
.
Tensor
|
None
=
None
# Create custom_routing_function to reuse memoized routing results
def
custom_routing_function
(
hidden_states
,
gating_output
,
topk
,
renormalize
):
"""Return memoized `topk_weights` and `topk_ids`."""
if
self
.
_memoized_topk_weights
is
None
or
self
.
_memoized_topk_ids
is
None
:
raise
RuntimeError
(
"ZeroExpertFusedMoE: routing results not memoized. "
"Call select_experts first to compute routing."
)
return
self
.
_memoized_topk_weights
,
self
.
_memoized_topk_ids
self
.
custom_routing_function
=
custom_routing_function
@
contextmanager
def
_temporarily_set_attrs
(
self
,
**
attrs
):
"""
Temporarily set attributes using object.__setattr__ and restore them.
This bypasses nn.Module.__setattr__ to avoid Dynamo tracing issues.
When PyTorch Dynamo traces the forward pass, it cannot handle
nn.Module.__setattr__ calls (which include parameter registration logic),
resulting in "Unsupported" errors. Using object.__setattr__ directly
sets the attribute without triggering nn.Module's custom __setattr__,
allowing Dynamo to trace the code successfully.
"""
originals
=
{
key
:
getattr
(
self
,
key
)
for
key
in
attrs
}
try
:
for
key
,
value
in
attrs
.
items
():
object
.
__setattr__
(
self
,
key
,
value
)
yield
finally
:
for
key
,
value
in
originals
.
items
():
object
.
__setattr__
(
self
,
key
,
value
)
def
_compute_zero_expert_result
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
"""Compute zero expert results using pre-computed routing."""
if
(
self
.
_actual_zero_expert_num
is
None
or
self
.
_actual_zero_expert_num
<=
0
or
self
.
_actual_zero_expert_type
is
None
):
return
None
return
zero_experts_compute_triton
(
expert_indices
=
topk_ids
.
clone
(),
expert_scales
=
topk_weights
.
clone
(),
num_experts
=
self
.
logical_num_experts
,
zero_expert_type
=
self
.
_actual_zero_expert_type
,
hidden_states
=
hidden_states
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
# Full logits including zero experts
)
->
torch
.
Tensor
:
"""
Forward pass with zero expert support and routing memoization.
Args:
hidden_states: Input hidden states
router_logits: Full router logits (including zero experts)
Returns:
Combined output from real experts and zero experts
"""
# Prepare temporary attribute overrides for routing computation
temp_attrs
=
{
"custom_routing_function"
:
None
,
# Disable for first routing
}
if
self
.
_router
is
not
None
:
temp_attrs
[
"e_score_correction_bias"
]
=
self
.
_router
.
e_score_correction_bias
# Compute routing with temporary attributes
# Pass full router_logits (including zero experts) so that zero experts
# can be properly identified in topk_ids
with
self
.
_temporarily_set_attrs
(
**
temp_attrs
):
topk_weights
,
topk_ids
=
self
.
select_experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
# Full logits (includes zero experts)
)
# Compute zero expert result if needed
zero_expert_result
=
self
.
_compute_zero_expert_result
(
hidden_states
=
hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
)
# Memoize routing results for reuse in super().forward()
self
.
_memoized_topk_weights
=
topk_weights
self
.
_memoized_topk_ids
=
topk_ids
# Slice router_logits for real experts only
router_logits_sliced
=
router_logits
[...,
:
self
.
logical_num_experts
]
# Compute real expert results (will reuse memoized routing via
# custom_routing_function)
# zero_expert_num is already 0, so FusedMoE won't handle zero experts
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits_sliced
,
)
# Combine results
# Both zero_expert_result and fused_out are computed from the same
# hidden_states, so they should be on the same device.
if
zero_expert_result
is
not
None
:
fused_out
=
fused_out
+
zero_expert_result
# Clear memoization after use
self
.
_memoized_topk_weights
=
None
self
.
_memoized_topk_ids
=
None
return
fused_out
vllm/model_executor/models/longcat_flash.py
View file @
19ec9a0a
...
@@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig
...
@@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
ZeroExpertFusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -292,12 +292,10 @@ class LongcatMoe(nn.Module):
...
@@ -292,12 +292,10 @@ class LongcatMoe(nn.Module):
prefix
=
f
"
{
prefix
}
.gate"
,
prefix
=
f
"
{
prefix
}
.gate"
,
)
)
assert
config
.
zero_expert_num
is
not
None
assert
config
.
zero_expert_type
is
not
None
assert
config
.
zero_expert_type
is
not
None
self
.
experts
=
ZeroExpertFusedMoE
(
self
.
experts
=
FusedMoE
(
zero_expert_num
=
config
.
zero_expert_num
,
zero_expert_type
=
config
.
zero_expert_type
,
zero_expert_type
=
config
.
zero_expert_type
,
router
=
self
.
router
,
e_score_correction_bias
=
self
.
router
.
e_score_correction_bias
,
num_experts
=
num_experts
,
num_experts
=
num_experts
,
top_k
=
top_k
,
top_k
=
top_k
,
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
...
@@ -332,7 +330,7 @@ class LongcatMoe(nn.Module):
...
@@ -332,7 +330,7 @@ class LongcatMoe(nn.Module):
hidden_states_padded
.
to
(
self
.
router_params_dtype
)
hidden_states_padded
.
to
(
self
.
router_params_dtype
)
)
)
#
ZeroExpert
FusedMoE handles routing memoization and zero expert computation
# FusedMoE handles routing memoization and zero expert computation
# internally. Pass full router_logits (including zero experts) so that
# internally. Pass full router_logits (including zero experts) so that
# zero experts can be properly identified in routing.
# zero experts can be properly identified in routing.
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment