Unverified commit e74698c2, authored by bnellnm, committed by GitHub
Browse files

[Misc][Refactor] Add FusedMoERouter object (#30519)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent aa125ecf
...@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device): ...@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method # Test the select_experts method
topk_weights, topk_ids = fused_moe.select_experts( topk_weights, topk_ids = fused_moe.router.select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
...@@ -48,6 +51,7 @@ def get_config() -> dict[str, Any] | None: ...@@ -48,6 +51,7 @@ def get_config() -> dict[str, Any] | None:
__all__ = [ __all__ = [
"FusedMoE", "FusedMoE",
"FusedMoERouter",
"FusedMoEConfig", "FusedMoEConfig",
"FusedMoEMethodBase", "FusedMoEMethodBase",
"UnquantizedFusedMoEMethod", "UnquantizedFusedMoEMethod",
......
...@@ -10,6 +10,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -10,6 +10,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
...@@ -109,6 +112,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -109,6 +112,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def apply( def apply(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
......
...@@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel, FusedMoEModularKernel,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
...@@ -88,10 +89,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -88,10 +89,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply( def apply(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
class FusedMoERouter(ABC):
    """
    Abstract interface for the routing step of a fused MoE layer.

    Concrete routers report which routing method they use and implement
    `select_experts`, which routes hidden states to experts based on the
    router logits.
    """

    @property
    @abstractmethod
    def routing_method_type(self) -> RoutingMethodType:
        """The `RoutingMethodType` used by this router."""
        raise NotImplementedError

    @abstractmethod
    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Select the top-k experts for the given hidden states using the
        router logits.

        Args:
            hidden_states: The input hidden states to be routed.
            router_logits: The router logits that drive the selection.

        Returns:
            A `(topk_weights, topk_ids)` tuple holding the computed
            weights and expert ids.

        **Compatibility**: When EPLB is not enabled, the returned ids are
        equivalent to global logical ids, so they should be compatible
        with plain MoE implementations without redundant experts.
        """
        raise NotImplementedError
...@@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data, init_aiter_topK_meta_data,
) )
...@@ -284,6 +285,23 @@ def maybe_roundup_hidden_size( ...@@ -284,6 +285,23 @@ def maybe_roundup_hidden_size(
return hidden_size return hidden_size
class FusedMoERouterImpl(FusedMoERouter):
    """
    `FusedMoERouter` implementation that delegates to a `FusedMoE` layer.

    The router stores the layer it was constructed with and forwards both
    the routing method type lookup and expert selection to that layer.
    """

    def __init__(self, layer: "FusedMoE"):
        super().__init__()
        # Layer that owns the routing configuration and `_select_experts`.
        self.layer = layer

    @property
    def routing_method_type(self) -> RoutingMethodType:
        """Return the routing method type stored on the wrapped layer."""
        return self.layer.routing_method_type

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward expert selection to the wrapped layer's `_select_experts`.

        Args:
            hidden_states: The input hidden states to be routed.
            router_logits: The router logits that drive the selection.

        Returns:
            Whatever the layer's `_select_experts` returns, annotated as a
            `(topk_weights, topk_ids)` tuple.
        """
        owner = self.layer
        return owner._select_experts(hidden_states, router_logits)
@CustomOp.register("fused_moe") @CustomOp.register("fused_moe")
class FusedMoE(CustomOp): class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
...@@ -339,7 +357,7 @@ class FusedMoE(CustomOp): ...@@ -339,7 +357,7 @@ class FusedMoE(CustomOp):
is_sequence_parallel=False, is_sequence_parallel=False,
expert_mapping: list[tuple[str, str, int, str]] | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None, n_shared_experts: int | None = None,
routing_method_type: int | None = None, routing_method_type: RoutingMethodType | None = None,
router_logits_dtype: torch.dtype | None = None, router_logits_dtype: torch.dtype | None = None,
): ):
super().__init__() super().__init__()
...@@ -529,7 +547,7 @@ class FusedMoE(CustomOp): ...@@ -529,7 +547,7 @@ class FusedMoE(CustomOp):
# ToDo: Better logic to determine the routing method type # ToDo: Better logic to determine the routing method type
if routing_method_type is not None: if routing_method_type is not None:
self.routing_method_type = routing_method_type self.routing_method_type: RoutingMethodType = routing_method_type
else: else:
if scoring_func == "sigmoid": if scoring_func == "sigmoid":
if self.use_grouped_topk: if self.use_grouped_topk:
...@@ -640,6 +658,8 @@ class FusedMoE(CustomOp): ...@@ -640,6 +658,8 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None
self.router = FusedMoERouterImpl(self)
# Note: maybe_init_modular_kernel should only be called by # Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model. # prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it # This is called after all weight loading and post-processing, so it
...@@ -1503,7 +1523,7 @@ class FusedMoE(CustomOp): ...@@ -1503,7 +1523,7 @@ class FusedMoE(CustomOp):
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
) )
def select_experts( def _select_experts(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -1772,6 +1792,7 @@ class FusedMoE(CustomOp): ...@@ -1772,6 +1792,7 @@ class FusedMoE(CustomOp):
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
router=self.router,
x=staged_hidden_states, x=staged_hidden_states,
router_logits=staged_router_logits, router_logits=staged_router_logits,
) )
...@@ -1944,6 +1965,7 @@ class FusedMoE(CustomOp): ...@@ -1944,6 +1965,7 @@ class FusedMoE(CustomOp):
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
router=self.router,
x=hidden_states_combined x=hidden_states_combined
if do_naive_dispatch_combine if do_naive_dispatch_combine
else hidden_states, else hidden_states,
......
...@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( ...@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
...@@ -285,10 +286,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -285,10 +286,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def apply( def apply(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward( return self.forward(
router=router,
layer=layer, layer=layer,
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -306,10 +309,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -306,10 +309,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda( def forward_cuda(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -332,6 +336,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -332,6 +336,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cpu( def forward_cpu(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -365,6 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -365,6 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_xpu( def forward_xpu(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
......
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoEMethodBase, FusedMoEMethodBase,
...@@ -759,12 +760,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -759,12 +760,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -10,7 +10,11 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -10,7 +10,11 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
LinearMethodBase, LinearMethodBase,
...@@ -495,12 +499,13 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -495,12 +499,13 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -40,6 +40,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( ...@@ -40,6 +40,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts, MarlinExperts,
fused_marlin_moe, fused_marlin_moe,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
...@@ -458,6 +459,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -458,6 +459,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -484,7 +486,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -484,7 +486,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
x_routing, _ = x x_routing, _ = x
else: else:
x_routing = x x_routing = x
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x_routing, hidden_states=x_routing,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -926,10 +928,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -926,10 +928,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -1066,12 +1069,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1066,12 +1069,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -1426,6 +1430,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1426,6 +1430,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -1433,7 +1438,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1433,7 +1438,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
f"{layer.activation} not supported for Marlin MoE." f"{layer.activation} not supported for Marlin MoE."
) )
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -1677,12 +1682,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1677,12 +1682,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -1978,6 +1984,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1978,6 +1984,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -2290,6 +2297,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2290,6 +2297,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
): ):
...@@ -2298,7 +2306,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2298,7 +2306,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
) )
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
int8_w8a16_moe_quant_config, int8_w8a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -137,12 +138,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -137,12 +138,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
...@@ -997,6 +998,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -997,6 +998,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -1051,7 +1053,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1051,7 +1053,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -16,7 +16,11 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -16,7 +16,11 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
LinearMethodBase, LinearMethodBase,
...@@ -629,6 +633,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -629,6 +633,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -639,7 +644,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -639,7 +644,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method." "fused GGUF MoE method."
) )
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoEMethodBase, FusedMoEMethodBase,
...@@ -895,12 +896,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -895,12 +896,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -9,6 +9,9 @@ from torch.nn import Module ...@@ -9,6 +9,9 @@ from torch.nn import Module
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
LinearMethodBase, LinearMethodBase,
...@@ -384,6 +387,7 @@ class XPUFp8MoEMethod(Fp8OnlineMoEMethod): ...@@ -384,6 +387,7 @@ class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -14,8 +14,10 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant ...@@ -14,8 +14,10 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoEMethodBase, FusedMoEMethodBase,
...@@ -200,7 +202,9 @@ class ModelOptQuantConfigBase(QuantizationConfig): ...@@ -200,7 +202,9 @@ class ModelOptQuantConfigBase(QuantizationConfig):
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method return quant_method
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer) quant_method = self.FusedMoEMethodCls(
quant_config=self, moe_config=layer.moe_config
)
if getattr(quant_method, "backend", "") == "marlin": if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method return quant_method
...@@ -720,14 +724,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -720,14 +724,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def __init__( def __init__(
self, self,
quant_config: ModelOptFp8Config, quant_config: ModelOptFp8Config,
layer: FusedMoE, moe_config: FusedMoEConfig,
) -> None: ) -> None:
super().__init__(layer.moe_config) super().__init__(moe_config)
self.quant_config = quant_config self.quant_config = quant_config
assert self.quant_config.is_checkpoint_fp8_serialized assert self.quant_config.is_checkpoint_fp8_serialized
self.fp8_backend = select_fp8_moe_backend( self.fp8_backend = select_fp8_moe_backend(
block_quant=False, block_quant=False,
tp_size=layer.moe_parallel_config.tp_size, tp_size=moe_config.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled, with_lora_support=self.moe.is_lora_enabled,
) )
self.kernel: mk.FusedMoEModularKernel | None = None self.kernel: mk.FusedMoEModularKernel | None = None
...@@ -935,6 +939,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -935,6 +939,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -961,7 +966,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -961,7 +966,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
topk_weights, topk_ids = layer.select_experts( # Expert selection
topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -1325,9 +1331,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1325,9 +1331,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def __init__( def __init__(
self, self,
quant_config: ModelOptNvFp4Config, quant_config: ModelOptNvFp4Config,
layer: FusedMoE, moe_config: FusedMoEConfig,
) -> None: ) -> None:
super().__init__(layer.moe_config) super().__init__(moe_config)
self.quant_config = quant_config self.quant_config = quant_config
self.nvfp4_backend = select_nvfp4_moe_backend() self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle. # TODO: move this type of check into the oracle.
...@@ -1597,6 +1603,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1597,6 +1603,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -1621,7 +1628,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1621,7 +1628,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x_routing, _ = x x_routing, _ = x
else: else:
x_routing = x x_routing = x
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x_routing, hidden_states=x_routing,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
int4_w4a16_moe_quant_config, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config, int8_w8a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoEConfig, FusedMoEConfig,
...@@ -364,13 +365,14 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -364,13 +365,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -27,6 +27,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( ...@@ -27,6 +27,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts, MarlinExperts,
fused_marlin_moe, fused_marlin_moe,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts, OAITritonExperts,
UnfusedOAITritonExperts, UnfusedOAITritonExperts,
...@@ -891,6 +892,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -891,6 +892,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -898,7 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -898,7 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -992,7 +994,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -992,7 +994,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
): ):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -1119,7 +1121,8 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): ...@@ -1119,7 +1121,8 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEConfig, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoERouter,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -350,10 +351,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -350,10 +351,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
...@@ -542,6 +544,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -542,6 +544,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
...@@ -750,10 +753,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -750,10 +753,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -15,7 +15,11 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,7 +15,11 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
LinearMethodBase, LinearMethodBase,
...@@ -356,10 +360,11 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -356,10 +360,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment