Unverified Commit 19ec9a0a authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Refactor ZeroExpertFusedMoE into new framework (#35549)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 1a9353bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for FusedMoE with zero experts.
Verifies that:
- The ZeroExpertRouter is properly created and used as the layer router.
- A forward pass through FusedMoE with zero experts produces correct output.
- The output decomposes correctly into real expert + zero expert contributions.
Note: tests generated with Claude.
"""
import pytest
import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
ZeroExpertRouter,
)
from vllm.v1.worker.workspace import init_workspace_manager
@pytest.fixture
def zero_expert_moe(dist_init, default_vllm_config):
"""Create a FusedMoE layer with zero experts."""
num_experts = 4
top_k = 2
# hidden_size must be >= 256 for the zero expert identity kernel to
# produce output (its BLOCK_SIZE=256 causes grid=0 when hidden_dim<256).
hidden_size = 256
intermediate_size = 512
zero_expert_num = 1
e_score_correction_bias = torch.zeros(
num_experts + zero_expert_num,
dtype=torch.float32,
device="cuda",
)
vllm_config = VllmConfig()
vllm_config.compilation_config.static_forward_context = dict()
with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
init_workspace_manager(torch.accelerator.current_device_index())
layer = FusedMoE(
zero_expert_type="identity",
e_score_correction_bias=e_score_correction_bias,
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=torch.bfloat16,
prefix="test_zero_expert_moe",
renormalize=False,
routed_scaling_factor=1.0,
scoring_func="softmax",
).cuda()
layer.quant_method.process_weights_after_loading(layer)
yield layer, vllm_config
@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_router_is_zero_expert_router(zero_expert_moe, num_tokens):
"""Verify that FusedMoE with zero_expert_type creates a ZeroExpertRouter."""
layer, _ = zero_expert_moe
assert isinstance(layer.router, ZeroExpertRouter), (
f"Expected ZeroExpertRouter but got {type(layer.router).__name__}."
)
@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_no_custom_routing_fn(zero_expert_moe, num_tokens):
"""Verify that custom_routing_function is not set (routing is handled
by ZeroExpertRouter, not a memoizing closure)."""
layer, _ = zero_expert_moe
assert layer.custom_routing_function is None
@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_forward(zero_expert_moe, num_tokens):
"""Run a forward pass through FusedMoE with zero experts and verify output shape."""
layer, vllm_config = zero_expert_moe
hidden_size = layer.hidden_size
num_experts = 4
zero_expert_num = 1
total_experts = num_experts + zero_expert_num
hidden_states = torch.randn(
num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda"
)
router_logits = torch.randn(
num_tokens, total_experts, dtype=torch.float32, device="cuda"
)
# Initialize weights to small random values to avoid NaN from
# uninitialized memory.
with torch.no_grad():
for param in layer.parameters():
if param.dtype.is_floating_point:
param.normal_(0, 0.01)
with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
get_forward_context().all_moe_layers = None
output = layer.forward(hidden_states, router_logits)
assert output.shape == hidden_states.shape, (
f"Expected output shape {hidden_states.shape}, got {output.shape}"
)
assert output.dtype == hidden_states.dtype
assert not torch.isnan(output).any(), "Output contains NaN values"
@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_output_decomposition(zero_expert_moe, num_tokens):
"""Validate that the FusedMoE output equals a plain FusedMoE
output (real experts only) plus the zero expert contribution.
The key invariant is:
zero_layer.forward(h, r_full) == plain_layer.forward(h, r_real)
+ zero_expert_output
We create a plain FusedMoE layer with the same weights and real-expert-only
router logits, compute the zero expert output via the ZeroExpertRouter, and
verify the sum matches the FusedMoE output.
"""
layer, vllm_config = zero_expert_moe
num_experts = 4
zero_expert_num = 1
total_experts = num_experts + zero_expert_num
hidden_states = torch.randn(
num_tokens, layer.hidden_size, dtype=torch.bfloat16, device="cuda"
)
router_logits = torch.randn(
num_tokens, total_experts, dtype=torch.float32, device="cuda"
)
with torch.no_grad():
for param in layer.parameters():
if param.dtype.is_floating_point:
param.normal_(0, 0.01)
with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
get_forward_context().all_moe_layers = None
# Create a plain FusedMoE layer with the same config but no zero
# experts. Use a separate prefix to avoid collision.
plain_layer = FusedMoE(
num_experts=num_experts,
top_k=layer.top_k,
hidden_size=layer.hidden_size,
intermediate_size=layer.intermediate_size_per_partition,
params_dtype=torch.bfloat16,
prefix="test_zero_expert_moe_plain",
renormalize=False,
scoring_func="softmax",
e_score_correction_bias=layer.e_score_correction_bias,
).cuda()
# Share weights from the zero expert layer.
plain_layer.w13_weight.data.copy_(layer.w13_weight.data)
plain_layer.w2_weight.data.copy_(layer.w2_weight.data)
plain_layer.quant_method.process_weights_after_loading(plain_layer)
# Compute routing via the ZeroExpertRouter. This produces masked
# topk_weights/topk_ids (zero expert entries have weight=0, id=0)
# and stores zero_expert_output as a side effect.
topk_weights, topk_ids = layer.router.select_experts(
hidden_states, router_logits
)
zero_output = layer.router.zero_expert_output
# Compute real expert output using the plain layer with the masked
# routing from the ZeroExpertRouter.
real_output = plain_layer.quant_method.apply(
layer=plain_layer,
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=None,
)
# Get the combined output from the zero expert layer.
full_output = layer.forward(hidden_states, router_logits)
assert zero_output is not None, "Zero expert output should not be None"
assert not torch.isnan(real_output).any(), "Real expert output has NaN"
assert not torch.isnan(zero_output).any(), "Zero expert output has NaN"
assert not torch.isnan(full_output).any(), "Full output has NaN"
expected = real_output + zero_output
torch.testing.assert_close(
full_output,
expected,
atol=0,
rtol=0,
msg="FusedMoE output should equal plain FusedMoE output "
"plus zero expert contribution",
)
@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_zero_expert_is_identity(zero_expert_moe, num_tokens):
"""Validate zero expert identity behavior.
When routing strongly favors the zero expert, its contribution should
be a scaled version of hidden_states (identity operation). We verify
this by manually computing the expected zero expert output from the
routing weights and comparing against what the router produces.
"""
layer, vllm_config = zero_expert_moe
num_experts = 4
zero_expert_num = 1
total_experts = num_experts + zero_expert_num
hidden_states = torch.randn(
num_tokens, layer.hidden_size, dtype=torch.bfloat16, device="cuda"
)
# Strongly bias toward the zero expert (index 4).
router_logits = torch.full(
(num_tokens, total_experts), -10.0, dtype=torch.float32, device="cuda"
)
router_logits[:, num_experts] = 10.0 # zero expert gets high logit
with torch.no_grad():
for param in layer.parameters():
if param.dtype.is_floating_point:
param.normal_(0, 0.01)
with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
get_forward_context().all_moe_layers = None
# Run routing to get topk_weights/topk_ids before masking.
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=layer.router.e_score_correction_bias.data,
topk=layer.top_k,
renormalize=layer.router.renormalize,
scoring_func=layer.router.scoring_func,
)
# Manually compute expected zero expert identity output:
# For each token, sum routing weights assigned to zero expert slots,
# then multiply by hidden_states.
zero_mask = topk_ids >= num_experts
zero_weight_per_token = (topk_weights * zero_mask.float()).sum(
dim=-1, keepdim=True
)
expected_zero_output = (hidden_states.float() * zero_weight_per_token).to(
hidden_states.dtype
)
# Run routing directly to trigger zero expert computation
# without going through the runner (which consumes the output).
layer.router.select_experts(hidden_states, router_logits)
actual_zero_output = layer.router.zero_expert_output
assert actual_zero_output is not None
assert zero_mask.any(), (
"With high zero expert logit, at least some slots should route "
"to the zero expert"
)
torch.testing.assert_close(
actual_zero_output,
expected_zero_output,
atol=1e-3,
rtol=1e-3,
msg="Zero expert identity output should equal "
"hidden_states * sum(zero_expert_weights)",
)
...@@ -33,9 +33,6 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE ...@@ -33,9 +33,6 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod, UnquantizedFusedMoEMethod,
) )
from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
ZeroExpertFusedMoE,
)
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
_config: dict[str, Any] | None = None _config: dict[str, Any] | None = None
...@@ -68,7 +65,6 @@ __all__ = [ ...@@ -68,7 +65,6 @@ __all__ = [
"GateLinear", "GateLinear",
"RoutingMethodType", "RoutingMethodType",
"SharedFusedMoE", "SharedFusedMoE",
"ZeroExpertFusedMoE",
"activation_without_mul", "activation_without_mul",
"apply_moe_activation", "apply_moe_activation",
"override_config", "override_config",
......
...@@ -274,6 +274,7 @@ class FusedMoE(PluggableLayer): ...@@ -274,6 +274,7 @@ class FusedMoE(PluggableLayer):
gate: torch.nn.Module | None = None, gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None, routed_input_transform: torch.nn.Module | None = None,
zero_expert_type: str | None = None,
): ):
super().__init__() super().__init__()
...@@ -462,6 +463,8 @@ class FusedMoE(PluggableLayer): ...@@ -462,6 +463,8 @@ class FusedMoE(PluggableLayer):
# TODO(bnell): once we can construct the MK at init time, we # TODO(bnell): once we can construct the MK at init time, we
# can make this a value. # can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype, indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
zero_expert_type=zero_expert_type,
num_logical_experts=self.logical_num_experts,
) )
self.routing_method_type: RoutingMethodType = self.router.routing_method_type self.routing_method_type: RoutingMethodType = self.router.routing_method_type
......
...@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import ( ...@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import ( from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
RoutingSimulatorRouter, RoutingSimulatorRouter,
) )
from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
ZeroExpertRouter,
)
EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState() EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState()
...@@ -49,6 +52,9 @@ def create_fused_moe_router( ...@@ -49,6 +52,9 @@ def create_fused_moe_router(
# eplb parameters # eplb parameters
enable_eplb: bool = False, enable_eplb: bool = False,
eplb_state: EplbLayerState = EMPTY_EPLB_STATE, eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
# zero expert parameters
zero_expert_type: str | None = None,
num_logical_experts: int | None = None,
) -> FusedMoERouter: ) -> FusedMoERouter:
""" """
Factory function to create the appropriate FusedMoERouter subclass based on Factory function to create the appropriate FusedMoERouter subclass based on
...@@ -56,10 +62,11 @@ def create_fused_moe_router( ...@@ -56,10 +62,11 @@ def create_fused_moe_router(
The selection logic follows this priority order: The selection logic follows this priority order:
1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set 1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
2. GroupedTopKRouter - if use_grouped_topk is True 2. ZeroExpertRouter - if zero_expert_type is not None
3. CustomRoutingRouter - if custom_routing_function is not None 3. GroupedTopKRouter - if use_grouped_topk is True
4. FusedTopKBiasRouter - if e_score_correction_bias is not None 4. CustomRoutingRouter - if custom_routing_function is not None
5. FusedTopKRouter - default fallback 5. FusedTopKBiasRouter - if e_score_correction_bias is not None
6. FusedTopKRouter - default fallback
Common arguments: Common arguments:
top_k: Number of experts to select per token top_k: Number of experts to select per token
...@@ -86,6 +93,12 @@ def create_fused_moe_router( ...@@ -86,6 +93,12 @@ def create_fused_moe_router(
enable_eplb: Whether EPLB is enabled enable_eplb: Whether EPLB is enabled
eplb_state: EPLB (Expert Parallelism Load Balancing) state eplb_state: EPLB (Expert Parallelism Load Balancing) state
Zero expert arguments:
zero_expert_type: Type of zero expert (e.g. identity). If not None,
creates a ZeroExpertRouter.
num_logical_experts: Number of real (non-zero) experts. Required when
zero_expert_type is not None.
Returns: Returns:
An instance of the appropriate FusedMoERouter subclass An instance of the appropriate FusedMoERouter subclass
""" """
...@@ -100,6 +113,27 @@ def create_fused_moe_router( ...@@ -100,6 +113,27 @@ def create_fused_moe_router(
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
) )
if zero_expert_type is not None:
assert num_logical_experts is not None, (
"num_logical_experts is required when zero_expert_type is set"
)
assert e_score_correction_bias is not None, (
"e_score_correction_bias is required when zero_expert_type is set"
)
return ZeroExpertRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias,
num_logical_experts=num_logical_experts,
zero_expert_type=zero_expert_type,
scoring_func=scoring_func,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
if use_grouped_topk: if use_grouped_topk:
assert custom_routing_function is None assert custom_routing_function is None
if num_expert_group is None or topk_group is None: if num_expert_group is None or topk_group is None:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
zero_experts_compute_triton,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
class ZeroExpertRouter(BaseRouter):
"""Router that handles zero expert computation as part of routing.
Routes over all experts (real + zero) using full e_score_correction_bias.
Computes zero expert identity contributions as a side effect during routing.
Remaps zero expert IDs to real expert ID 0 (with weight 0) so downstream
MoE computation can ignore them.
"""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor,
num_logical_experts: int,
zero_expert_type: str,
scoring_func: str = "softmax",
renormalize: bool = False,
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.e_score_correction_bias = e_score_correction_bias
self.num_logical_experts = num_logical_experts
self.zero_expert_type = zero_expert_type
self.scoring_func = scoring_func
self.renormalize = renormalize
self.routed_scaling_factor = routed_scaling_factor
self._zero_expert_output: torch.Tensor | None = None
@property
def routing_method_type(self) -> RoutingMethodType:
return get_routing_method_type(
scoring_func=self.scoring_func,
top_k=self.top_k,
renormalize=self.renormalize,
num_expert_group=None,
has_e_score_bias=True,
)
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing with full bias, compute zero expert output,
mask zero expert IDs."""
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
scoring_func=self.scoring_func,
indices_type=indices_type,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
# Compute zero expert output using pre-EPLB topk_ids/weights.
# zero_experts_compute_triton modifies its inputs in-place, so
# pass clones.
self._zero_expert_output = zero_experts_compute_triton(
expert_indices=topk_ids.clone(),
expert_scales=topk_weights.clone(),
num_experts=self.num_logical_experts,
zero_expert_type=self.zero_expert_type,
hidden_states=hidden_states,
)
# Mask zero expert entries: remap zero expert IDs to 0 with weight 0
# so downstream MoE computation ignores them.
zero_mask = topk_ids >= self.num_logical_experts
topk_ids[zero_mask] = 0
topk_weights[zero_mask] = 0.0
return topk_weights, topk_ids
@property
def zero_expert_output(self) -> torch.Tensor | None:
"""Retrieve and clear the zero expert output."""
output = self._zero_expert_output
self._zero_expert_output = None
return output
...@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( ...@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter, FusedMoERouter,
) )
from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
ZeroExpertRouter,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts, SharedExperts,
...@@ -443,6 +446,19 @@ class MoERunnerBase(MoERunner): ...@@ -443,6 +446,19 @@ class MoERunnerBase(MoERunner):
if self._shared_experts is not None: if self._shared_experts is not None:
self._shared_experts.maybe_sync_shared_experts_stream(shared_experts_input) self._shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)
def _maybe_add_zero_expert_output(
self,
result: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if isinstance(self.router, ZeroExpertRouter):
zero_expert_output = self.router.zero_expert_output
assert zero_expert_output is not None
if isinstance(result, tuple):
result = (result[0], result[1] + zero_expert_output)
else:
result = result + zero_expert_output
return result
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -494,7 +510,9 @@ class MoERunnerBase(MoERunner): ...@@ -494,7 +510,9 @@ class MoERunnerBase(MoERunner):
self._encode_layer_name(), self._encode_layer_name(),
) )
return self._maybe_reduce_output(fused_output, og_hidden_dims) result = self._maybe_reduce_output(fused_output, og_hidden_dims)
return self._maybe_add_zero_expert_output(result)
def forward_dispatch( def forward_dispatch(
self, self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import torch
from torch import nn
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
class ZeroExpertFusedMoE(FusedMoE):
"""
A FusedMoE operation that also computes the results of zero experts.
Zero experts perform identity operations (scaled pass-through) instead
of full MLP computations.
This class uses memoization to avoid redundant routing computation:
routing is computed once and reused for both zero expert computation
and the main FusedMoE forward pass.
"""
def __init__(
self,
zero_expert_num: int,
zero_expert_type: str,
router: nn.Module,
**kwargs,
):
# ZeroExpertFusedMoE manages its own custom_routing_function for memoization
assert (
"custom_routing_function" not in kwargs
or kwargs.get("custom_routing_function") is None
), (
"ZeroExpertFusedMoE does not support external custom_routing_function. "
"It manages its own for routing memoization."
)
# Automatically slice router's e_score_correction_bias to only include
# real experts (not zero_experts) for the base FusedMoE.
# The full bias will be used temporarily in forward() for routing.
if hasattr(router, "e_score_correction_bias") and "num_experts" in kwargs:
num_real_experts = kwargs["num_experts"]
router_bias = router.e_score_correction_bias
user_bias = kwargs.get("e_score_correction_bias")
# Use router's bias if:
# 1. User didn't provide bias, or
# 2. User provided full bias (same size as router)
if user_bias is None or user_bias.shape[0] == router_bias.shape[0]:
kwargs["e_score_correction_bias"] = router_bias[:num_real_experts]
# FusedMoE no longer accepts zero_expert_num/zero_expert_type.
# We handle zero experts ourselves in forward().
super().__init__(**kwargs)
# Store the actual zero_expert_num and zero_expert_type for our own use
self._actual_zero_expert_num = zero_expert_num
self._actual_zero_expert_type = zero_expert_type
self._router = router # Full router (includes zero experts)
# Expose zero_expert_num and zero_expert_type as attributes for
# compatibility with quantization methods that check these attributes
self.zero_expert_num = 0
self.zero_expert_type = None
# Memoization state for routing results
self._memoized_topk_weights: torch.Tensor | None = None
self._memoized_topk_ids: torch.Tensor | None = None
# Create custom_routing_function to reuse memoized routing results
def custom_routing_function(hidden_states, gating_output, topk, renormalize):
"""Return memoized `topk_weights` and `topk_ids`."""
if self._memoized_topk_weights is None or self._memoized_topk_ids is None:
raise RuntimeError(
"ZeroExpertFusedMoE: routing results not memoized. "
"Call select_experts first to compute routing."
)
return self._memoized_topk_weights, self._memoized_topk_ids
self.custom_routing_function = custom_routing_function
@contextmanager
def _temporarily_set_attrs(self, **attrs):
"""
Temporarily set attributes using object.__setattr__ and restore them.
This bypasses nn.Module.__setattr__ to avoid Dynamo tracing issues.
When PyTorch Dynamo traces the forward pass, it cannot handle
nn.Module.__setattr__ calls (which include parameter registration logic),
resulting in "Unsupported" errors. Using object.__setattr__ directly
sets the attribute without triggering nn.Module's custom __setattr__,
allowing Dynamo to trace the code successfully.
"""
originals = {key: getattr(self, key) for key in attrs}
try:
for key, value in attrs.items():
object.__setattr__(self, key, value)
yield
finally:
for key, value in originals.items():
object.__setattr__(self, key, value)
def _compute_zero_expert_result(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | None:
"""Compute zero expert results using pre-computed routing."""
if (
self._actual_zero_expert_num is None
or self._actual_zero_expert_num <= 0
or self._actual_zero_expert_type is None
):
return None
return zero_experts_compute_triton(
expert_indices=topk_ids.clone(),
expert_scales=topk_weights.clone(),
num_experts=self.logical_num_experts,
zero_expert_type=self._actual_zero_expert_type,
hidden_states=hidden_states,
)
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor, # Full logits including zero experts
) -> torch.Tensor:
"""
Forward pass with zero expert support and routing memoization.
Args:
hidden_states: Input hidden states
router_logits: Full router logits (including zero experts)
Returns:
Combined output from real experts and zero experts
"""
# Prepare temporary attribute overrides for routing computation
temp_attrs = {
"custom_routing_function": None, # Disable for first routing
}
if self._router is not None:
temp_attrs["e_score_correction_bias"] = self._router.e_score_correction_bias
# Compute routing with temporary attributes
# Pass full router_logits (including zero experts) so that zero experts
# can be properly identified in topk_ids
with self._temporarily_set_attrs(**temp_attrs):
topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states,
router_logits=router_logits, # Full logits (includes zero experts)
)
# Compute zero expert result if needed
zero_expert_result = self._compute_zero_expert_result(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
# Memoize routing results for reuse in super().forward()
self._memoized_topk_weights = topk_weights
self._memoized_topk_ids = topk_ids
# Slice router_logits for real experts only
router_logits_sliced = router_logits[..., : self.logical_num_experts]
# Compute real expert results (will reuse memoized routing via
# custom_routing_function)
# zero_expert_num is already 0, so FusedMoE won't handle zero experts
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits_sliced,
)
# Combine results
# Both zero_expert_result and fused_out are computed from the same
# hidden_states, so they should be on the same device.
if zero_expert_result is not None:
fused_out = fused_out + zero_expert_result
# Clear memoization after use
self._memoized_topk_weights = None
self._memoized_topk_ids = None
return fused_out
...@@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig ...@@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE, ZeroExpertFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -292,12 +292,10 @@ class LongcatMoe(nn.Module): ...@@ -292,12 +292,10 @@ class LongcatMoe(nn.Module):
prefix=f"{prefix}.gate", prefix=f"{prefix}.gate",
) )
assert config.zero_expert_num is not None
assert config.zero_expert_type is not None assert config.zero_expert_type is not None
self.experts = ZeroExpertFusedMoE( self.experts = FusedMoE(
zero_expert_num=config.zero_expert_num,
zero_expert_type=config.zero_expert_type, zero_expert_type=config.zero_expert_type,
router=self.router, e_score_correction_bias=self.router.e_score_correction_bias,
num_experts=num_experts, num_experts=num_experts,
top_k=top_k, top_k=top_k,
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -332,7 +330,7 @@ class LongcatMoe(nn.Module): ...@@ -332,7 +330,7 @@ class LongcatMoe(nn.Module):
hidden_states_padded.to(self.router_params_dtype) hidden_states_padded.to(self.router_params_dtype)
) )
# ZeroExpertFusedMoE handles routing memoization and zero expert computation # FusedMoE handles routing memoization and zero expert computation
# internally. Pass full router_logits (including zero experts) so that # internally. Pass full router_logits (including zero experts) so that
# zero experts can be properly identified in routing. # zero experts can be properly identified in routing.
final_hidden_states = self.experts( final_hidden_states = self.experts(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment