Unverified Commit 327a02d8 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Separate Router into OO Classes (#30623)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 2f03035a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Callable
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
and record the expert load metrics.
This will select a pseudo-random replica for each logical expert.
Only used for EPLB.
Args:
topk_ids: The logical expert ids.
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
Returns:
The physical expert ids.
"""
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long = topk_ids.long()
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count = logical_replica_count[topk_ids_long]
# Flatten-position based index, reshaped back to `topk_ids` shape
pos_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.long
).reshape_as(topk_ids)
# Compute pseudo-random indices by modulo
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = (
logical_to_physical_map[topk_ids_long]
.gather(-1, replica_indices)
.squeeze(-1)
)
topk_ids = physical_ids
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
# `expert_load_view`: (num_physical_experts,)
# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
return topk_ids
else:
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is
return topk_ids
class BaseRouter(FusedMoERouter):
"""
Base router class that provides common functionality for all router implementations.
This class implements the template method pattern where select_experts() handles
common pre-processing and post-processing, delegating the actual routing logic
to the abstract _compute_routing() method.
"""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
enable_eplb: bool = False,
# TODO(bnell): Once the MK is constructed at layer init time, we
# can make this a plain value instead of a callback.
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
"""
Note: the indices dtype might not be available at router construction
time, so we need to supply a callback to get it at runtime. This is
because the indices type is supplied by modular kernels which are
created after MoE layer/router construction.
"""
super().__init__()
self.top_k = top_k
self.global_num_experts = global_num_experts
self.eplb_state = eplb_state
self.enable_eplb = enable_eplb
self.indices_type_getter = indices_type_getter
self.capture: Callable[[torch.tensor], None] | None = None
def _validate_eplb_state(self) -> None:
"""Validate that EPLB state is properly initialized if EPLB is enabled."""
if self.enable_eplb:
if self.eplb_state.expert_load_view is None:
raise ValueError("enable_eplb=True requires expert_load_view != None")
if self.eplb_state.logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requires logical_to_physical_map != None"
)
if self.eplb_state.logical_replica_count is None:
raise ValueError(
"enable_eplb=True requires logical_replica_count != None"
)
def _get_indices_type(self) -> torch.dtype | None:
"""Get the desired indices dtype from the getter function."""
return (
self.indices_type_getter() if self.indices_type_getter is not None else None
)
def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
"""Apply EPLB mapping to convert logical expert IDs to physical expert IDs."""
if self.enable_eplb:
assert self.eplb_state.expert_load_view is not None
assert self.eplb_state.logical_to_physical_map is not None
assert self.eplb_state.logical_replica_count is not None
return eplb_map_to_physical_and_record(
topk_ids=topk_ids,
expert_load_view=self.eplb_state.expert_load_view,
logical_to_physical_map=self.eplb_state.logical_to_physical_map,
logical_replica_count=self.eplb_state.logical_replica_count,
)
return topk_ids
def _convert_indices_dtype(
self, topk_ids: torch.Tensor, indices_type: torch.dtype | None
) -> torch.Tensor:
"""Convert topk_ids to the desired dtype if needed."""
if (indices_type is not None) and topk_ids.dtype != indices_type:
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None
return topk_ids
@abstractmethod
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute the actual routing logic.
This method must be implemented by subclasses to provide the specific
routing algorithm (e.g., grouped_topk, fused_topk, custom routing, etc.).
Args:
hidden_states: Input hidden states
router_logits: Router logits for expert selection
indices_type: Desired dtype for expert indices (may be None)
Returns:
tuple of (topk_weights, topk_ids)
"""
raise NotImplementedError
def select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
This method implements the template method pattern:
1. Validates EPLB state
2. Gets indices type
3. Calls _compute_routing() to get topk_weights and topk_ids
4. Applies EPLB mapping if enabled
5. Converts indices dtype if needed
Returns:
(topk_weights, topk_ids)
(tuple[torch.Tensor, torch.Tensor]):
The weights and expert ids computation result.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
# Step 1: Validate EPLB state
self._validate_eplb_state()
# Step 2: Get indices type.
indices_type = self._get_indices_type()
# Step 3: Compute routing (delegated to subclass)
topk_weights, topk_ids = self._compute_routing(
hidden_states, router_logits, indices_type
)
# Step 4: Apply EPLB mapping
topk_ids = self._apply_eplb_mapping(topk_ids)
# Step 5: Convert indices dtype
topk_ids = self._convert_indices_dtype(topk_ids, indices_type)
# TODO(bnell): temporary hack until select_experts is moved into FusedMoE
if self.capture is not None:
self.capture(topk_ids)
return topk_weights, topk_ids
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
class CustomRoutingRouter(BaseRouter):
"""Router using a custom user-provided routing function."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
custom_routing_function: Callable,
renormalize: bool = True,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.custom_routing_function = custom_routing_function
self.renormalize = renormalize
@property
def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.Custom
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using the custom routing function."""
topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
)
return topk_weights.to(torch.float32), topk_ids
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
):
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
class FusedTopKBiasRouter(BaseRouter):
"""Router using fused top-k with e_score_correction_bias."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor,
renormalize: bool = True,
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.e_score_correction_bias = e_score_correction_bias
self.renormalize = renormalize
self.routed_scaling_factor = routed_scaling_factor
@property
def routing_method_type(self) -> RoutingMethodType:
return (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using fused top-k with bias."""
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
return topk_weights, topk_ids
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
def vllm_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
return topk_weights, topk_indices
def dispatch_topk_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
M, _ = hidden_states.size()
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
class FusedTopKRouter(BaseRouter):
"""Default router using standard fused top-k routing."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
scoring_func: str = "softmax",
renormalize: bool = True,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
assert scoring_func == "softmax", "FusedTopKRouter only supports softmax."
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.renormalize = renormalize
@property
def routing_method_type(self) -> RoutingMethodType:
return (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using standard fused top-k."""
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
)
return topk_weights, topk_ids
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from functools import partial
import torch
from vllm import _custom_ops as ops
from vllm import envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "sigmoid":
# Fully fused kernel path for sigmoid
topk_values, topk_indices = ops.grouped_topk(
gating_output, # raw logits
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
1, # scoring_func=1 for sigmoid
)
elif scoring_func == "softmax":
# Apply softmax in Python, then use fused kernel
# TODO: Add support for softmax in kernel
scores = torch.softmax(gating_output, dim=-1)
topk_values, topk_indices = ops.grouped_topk(
scores, # pre-computed scores
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
0, # scoring_func=0 (no activation, scores already computed)
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# Fused kernel outputs float32 values and int32 indices directly
return topk_values, topk_indices
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(
dynamic=True,
backend=current_platform.simple_compile_backend,
options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if (
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
and current_platform.is_cuda()
and num_expert_group <= 32
and topk <= 32
and e_score_correction_bias is not None
):
return fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
)
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0)
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=use_sorted
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
"""GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
# --8<-- [end:grouped_topk]
def __init__(
self,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
num_fused_shared_experts: int = 0,
) -> None:
super().__init__()
self.native_impl = grouped_topk
self.topk = topk
self.renormalize = renormalize
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.num_fused_shared_experts = num_fused_shared_experts
def forward_native(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.native_impl(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
def forward_hip(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
return rocm_aiter_grouped_topk(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
self.num_fused_shared_experts,
)
else:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
class GroupedTopKRouter(BaseRouter):
"""Router using grouped top-k routing (e.g., DeepSeekV2/V3)."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
num_expert_group: int,
topk_group: int,
renormalize: bool = True,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
num_fused_shared_experts: int = 0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
routing_method_type: RoutingMethodType | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.renormalize = renormalize
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.e_score_correction_bias = e_score_correction_bias
self.num_fused_shared_experts = num_fused_shared_experts
# Determine routing method type
if routing_method_type is not None:
self._routing_method_type = routing_method_type
elif scoring_func == "sigmoid":
self._routing_method_type = RoutingMethodType.DeepSeekV3
else:
self._routing_method_type = RoutingMethodType.TopK
@property
def routing_method_type(self) -> RoutingMethodType:
return self._routing_method_type
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using grouped top-k."""
def valid_grouping() -> bool:
# Check if num_experts is greater than num_expert_group
# and is divisible by num_expert_group
num_experts = router_logits.shape[-1]
if num_experts <= self.num_expert_group:
return False
return num_experts % self.num_expert_group == 0
if not valid_grouping():
if self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
else:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
)
return topk_weights, topk_ids
# Select grouped_topk implementation
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
num_fused_shared_experts=self.num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
)
return topk_weights, topk_ids
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
CustomRoutingRouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
FusedTopKBiasRouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
FusedTopKRouter,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
GroupedTopKRouter,
)
from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
RoutingSimulatorRouter,
)
EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState()
def create_fused_moe_router(
# common parameters
top_k: int,
global_num_experts: int,
renormalize: bool = True,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
routing_method_type: RoutingMethodType | None = None,
# grouped topk parameters
use_grouped_topk: bool = False,
num_expert_group: int | None = None,
topk_group: int | None = None,
scoring_func: str = "softmax",
num_fused_shared_experts: int = 0,
# grouped topk + fused topk bias parameters
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
# custom routing paramaters
custom_routing_function: Callable | None = None,
# eplb parameters
enable_eplb: bool = False,
eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
capture: Callable[[torch.tensor], None] | None = None,
) -> FusedMoERouter:
"""
Factory function to create the appropriate FusedMoERouter subclass based on
the provided parameters.
The selection logic follows this priority order:
1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
2. GroupedTopKRouter - if use_grouped_topk is True
3. CustomRoutingRouter - if custom_routing_function is not None
4. FusedTopKBiasRouter - if e_score_correction_bias is not None
5. FusedTopKRouter - default fallback
Common arguments:
top_k: Number of experts to select per token
global_num_experts: Total number of experts in the model
renormalize: Whether to renormalize the routing weights
indices_type_getter: Function to get the desired indices dtype
routing_method_type: Optional explicit routing method type
Grouped topk arguments:
use_grouped_topk: Whether to use grouped top-k routing
num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group (for grouped routing)
scoring_func: Scoring function to use ("softmax" or "sigmoid")
num_fused_shared_experts: Number of fused shared experts (for ROCm AITER)
Grouped topk and fused topk bias arguments:
routed_scaling_factor: Scaling factor for routed weights
e_score_correction_bias: Optional bias correction for expert scores
Custom routing arguments:
custom_routing_function: Optional custom routing function
EPLB arguments:
enable_eplb: Whether EPLB is enabled
eplb_state: EPLB (Expert Parallelism Load Balancing) state
Returns:
An instance of the appropriate FusedMoERouter subclass
"""
router: BaseRouter
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
router = RoutingSimulatorRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
# TODO(bnell): this is temporary until select_experts is
# separated from apply.
router.capture = capture
return router
if use_grouped_topk:
assert custom_routing_function is None
if num_expert_group is None or topk_group is None:
raise ValueError(
"num_expert_group and topk_group must be provided when "
"use_grouped_topk is True"
)
router = GroupedTopKRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
num_expert_group=num_expert_group,
topk_group=topk_group,
renormalize=renormalize,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=num_fused_shared_experts,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
routing_method_type=routing_method_type,
)
router.capture = capture
return router
if custom_routing_function is not None:
router = CustomRoutingRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
custom_routing_function=custom_routing_function,
renormalize=renormalize,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
if scoring_func != "softmax":
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
)
if e_score_correction_bias is not None:
router = FusedTopKBiasRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
router = FusedTopKRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
renormalize=renormalize,
scoring_func=scoring_func,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Token-to-Expert Routing Simulator
This module provides a framework for simulating and testing different
token-to-expert routing strategies for Mixture of Experts (MoE) models.
It supports routing logic customization and includes example implementations
like uniform random routing.
"""
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch
import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
logger = init_logger(__name__)
......@@ -308,3 +304,44 @@ class RoutingSimulator:
top_k=top_k,
indices_type=indices_type,
)
class RoutingSimulatorRouter(BaseRouter):
"""Router that uses routing simulation strategies for testing/debugging."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
@property
def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.Simulated
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Use routing simulator to compute routing."""
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
top_k=self.top_k,
indices_type=indices_type,
)
return topk_weights, topk_ids
......@@ -20,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute,
......@@ -32,6 +31,9 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel,
select_unquantized_moe_backend,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
......@@ -312,9 +314,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
or layer.eplb_state.expert_load_view is not None
or layer.eplb_state.logical_to_physical_map is not None
or layer.eplb_state.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for CPU.")
......@@ -346,9 +348,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
or layer.eplb_state.expert_load_view is not None
or layer.eplb_state.logical_to_physical_map is not None
or layer.eplb_state.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion(
......
......@@ -10,12 +10,12 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
......
......@@ -6,11 +6,11 @@ from typing import Any, Union
import torch
from packaging import version
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
......
......@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoERouter,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
......@@ -40,7 +41,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
......
......@@ -10,12 +10,12 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
......
......@@ -23,13 +23,13 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoERouter,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
......
......@@ -12,11 +12,11 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
......
......@@ -10,12 +10,12 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
......
......@@ -8,10 +8,10 @@ from packaging import version
from torch.nn import Module
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
from vllm.model_executor.layers.fused_moe import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
......
......@@ -13,11 +13,11 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
......
......@@ -6,12 +6,12 @@ from typing import Any, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment