Unverified commit dc917cce, authored by bnellnm, committed by GitHub
Browse files

[MoE Refactor] Move `select_experts` from `FusedMoEQuantMethod` -> `FusedMoE` (#31996)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent fc56f4a0
...@@ -957,18 +957,18 @@ class MarlinMoEWeightData: ...@@ -957,18 +957,18 @@ class MarlinMoEWeightData:
) )
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe( def test_fused_marlin_moe(
a_type, a_type: ScalarType,
b_type, b_type: ScalarType,
c_type, c_type: ScalarType,
group_blocks, group_blocks: int,
m, m: int,
n, n: int,
k, k: int,
e, e: int,
topk, topk: int,
ep_size, ep_size: int,
act_order, act_order: bool,
is_k_full, is_k_full: bool,
): ):
torch.cuda.manual_seed(1) torch.cuda.manual_seed(1)
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16 group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
...@@ -1044,7 +1044,6 @@ def test_fused_marlin_moe( ...@@ -1044,7 +1044,6 @@ def test_fused_marlin_moe(
None, None,
w1_data.scales, w1_data.scales,
w2_data.scales, w2_data.scales,
score,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_num_experts=e, global_num_experts=e,
...@@ -1120,7 +1119,6 @@ def test_fused_marlin_moe_with_bias(m): ...@@ -1120,7 +1119,6 @@ def test_fused_marlin_moe_with_bias(m):
w2_data.marlin_bias, w2_data.marlin_bias,
w1_data.scales, w1_data.scales,
w2_data.scales, w2_data.scales,
score,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_num_experts=e, global_num_experts=e,
...@@ -1199,7 +1197,6 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int): ...@@ -1199,7 +1197,6 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
None, # bias2 None, # bias2
w1_data.scales, w1_data.scales,
w2_data.scales, w2_data.scales,
score,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_num_experts=e, global_num_experts=e,
...@@ -1519,7 +1516,6 @@ def test_batched_fused_marlin_moe( ...@@ -1519,7 +1516,6 @@ def test_batched_fused_marlin_moe(
"bias2": None, "bias2": None,
"w1_scale": w1_data.scales, "w1_scale": w1_data.scales,
"w2_scale": w2_data.scales, "w2_scale": w2_data.scales,
"gating_output": score,
"global_num_experts": e, "global_num_experts": e,
"expert_map": None, "expert_map": None,
"global_scale1": w1_data.global_scale, "global_scale1": w1_data.global_scale,
......
...@@ -210,7 +210,6 @@ def fused_marlin_moe( ...@@ -210,7 +210,6 @@ def fused_marlin_moe(
bias2: torch.Tensor | None, bias2: torch.Tensor | None,
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
gating_output: torch.Tensor | None,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
quant_type_id: int, quant_type_id: int,
...@@ -250,8 +249,6 @@ def fused_marlin_moe( ...@@ -250,8 +249,6 @@ def fused_marlin_moe(
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1. - w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2. - w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor|None): The output of the gating
operation (before softmax).
- g_idx1 (torch.Tensor|None): The first set of act_order indices. - g_idx1 (torch.Tensor|None): The first set of act_order indices.
- g_idx2 (torch.Tensor|None): The second set of act_order indices. - g_idx2 (torch.Tensor|None): The second set of act_order indices.
- sort_indices1 (torch.Tensor|None): The first act_order input - sort_indices1 (torch.Tensor|None): The first act_order input
...@@ -292,8 +289,6 @@ def fused_marlin_moe( ...@@ -292,8 +289,6 @@ def fused_marlin_moe(
topk = topk_ids.size(1) topk = topk_ids.size(1)
# Check constraints. # Check constraints.
if gating_output is not None:
assert gating_output.size(0) == M, "Number of tokens mismatch"
assert w1.size(1) * 16 == K, "Hidden size mismatch w1" assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2" assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
...@@ -381,7 +376,6 @@ def batched_fused_marlin_moe( ...@@ -381,7 +376,6 @@ def batched_fused_marlin_moe(
bias2: torch.Tensor | None, bias2: torch.Tensor | None,
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
gating_output: torch.Tensor | None,
quant_type_id: int, quant_type_id: int,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -718,7 +712,6 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -718,7 +712,6 @@ class MarlinExperts(MarlinExpertsBase):
bias2=self.w2_bias, bias2=self.w2_bias,
w1_scale=self.w1_scale, w1_scale=self.w1_scale,
w2_scale=self.w2_scale, w2_scale=self.w2_scale,
gating_output=None,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
global_scale1=self.g1_alphas, global_scale1=self.g1_alphas,
...@@ -833,7 +826,6 @@ class BatchedMarlinExperts(MarlinExpertsBase): ...@@ -833,7 +826,6 @@ class BatchedMarlinExperts(MarlinExpertsBase):
bias2=self.w2_bias, bias2=self.w2_bias,
w1_scale=self.w1_scale, w1_scale=self.w1_scale,
w2_scale=self.w2_scale, w2_scale=self.w2_scale,
gating_output=None,
quant_type_id=self.quant_type_id, quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
......
...@@ -14,9 +14,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -14,9 +14,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
...@@ -108,11 +105,24 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -108,11 +105,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def method_name(self) -> str: def method_name(self) -> str:
return self.__class__.__name__ return self.__class__.__name__
@abstractmethod @property
def is_monolithic(self) -> bool:
return False
# @abstractmethod
def apply( def apply(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter, x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
# @abstractmethod
def apply_monolithic(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
......
...@@ -16,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -16,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel, FusedMoEModularKernel,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -40,6 +37,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -40,6 +37,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
not self.fused_experts.supports_expert_map(), not self.fused_experts.supports_expert_map(),
) )
self.old_quant_method = old_quant_method self.old_quant_method = old_quant_method
assert not self.old_quant_method.is_monolithic
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
@staticmethod @staticmethod
...@@ -94,16 +92,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -94,16 +92,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply( def apply(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts( return self.fused_experts(
hidden_states=x,
router_logits=router_logits,
)
result = self.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -115,5 +108,3 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -115,5 +108,3 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
) )
return result
...@@ -522,8 +522,7 @@ class FusedMoE(CustomOp): ...@@ -522,8 +522,7 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation self.activation = activation
# TODO(bnell): in next PR move capture back to layer self.capture: Callable[[torch.Tensor], None] | None = None
capture: Callable[[torch.Tensor], None] | None = None
if ( if (
self.vllm_config.model_config is not None self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts and self.vllm_config.model_config.enable_return_routed_experts
...@@ -531,7 +530,9 @@ class FusedMoE(CustomOp): ...@@ -531,7 +530,9 @@ class FusedMoE(CustomOp):
# In dummy runs, the capturer is not initialized. # In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance() capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None: if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids) self.capture = lambda topk_ids: capturer.capture(
self.layer_id, topk_ids
)
self.router = create_fused_moe_router( self.router = create_fused_moe_router(
top_k=top_k, top_k=top_k,
...@@ -550,7 +551,6 @@ class FusedMoE(CustomOp): ...@@ -550,7 +551,6 @@ class FusedMoE(CustomOp):
# TODO(bnell): once we can construct the MK at init time, we # TODO(bnell): once we can construct the MK at init time, we
# can make this a value. # can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype, indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
capture=capture,
) )
self.routing_method_type: RoutingMethodType = self.router.routing_method_type self.routing_method_type: RoutingMethodType = self.router.routing_method_type
...@@ -1673,12 +1673,27 @@ class FusedMoE(CustomOp): ...@@ -1673,12 +1673,27 @@ class FusedMoE(CustomOp):
staged_router_logits.copy_(router_logits, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True)
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=self, layer=self,
router=self.router,
x=staged_hidden_states, x=staged_hidden_states,
router_logits=staged_router_logits, router_logits=staged_router_logits,
) )
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=staged_hidden_states,
router_logits=staged_router_logits,
)
if self.capture is not None:
self.capture(topk_ids)
final_hidden_states = self.quant_method.apply(
layer=self,
x=staged_hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
if has_separate_shared_experts: if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple) assert not isinstance(final_hidden_states, tuple)
...@@ -1810,15 +1825,20 @@ class FusedMoE(CustomOp): ...@@ -1810,15 +1825,20 @@ class FusedMoE(CustomOp):
extra_tensors=extra_tensors, extra_tensors=extra_tensors,
) )
if extra_tensors is not None: if extra_tensors is not None:
hidden_states_combined, router_logits, extra_tensors_combined = ( (
dispatch_res orig_hidden_states,
) router_logits,
extra_tensors_combined,
) = dispatch_res
hidden_states_combined = ( hidden_states_combined = (
hidden_states_combined, orig_hidden_states,
extra_tensors_combined[0], extra_tensors_combined[0],
) )
else: else:
hidden_states_combined, router_logits = dispatch_res hidden_states_combined, router_logits = dispatch_res
orig_hidden_states = hidden_states_combined
else:
orig_hidden_states = hidden_states
# Run shared experts before matrix multiply. # Run shared experts before matrix multiply.
# because matrix multiply may modify the hidden_states. # because matrix multiply may modify the hidden_states.
...@@ -1840,14 +1860,33 @@ class FusedMoE(CustomOp): ...@@ -1840,14 +1860,33 @@ class FusedMoE(CustomOp):
) )
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( x = hidden_states_combined if do_naive_dispatch_combine else hidden_states
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
# Figure out nicer way to do this.
x_orig = orig_hidden_states if do_naive_dispatch_combine else hidden_states
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=self, layer=self,
router=self.router, x=x,
x=hidden_states_combined
if do_naive_dispatch_combine
else hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=x_orig,
router_logits=router_logits,
)
if self.capture is not None:
self.capture(topk_ids)
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signature of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
)
if has_separate_shared_experts: if has_separate_shared_experts:
assert self.shared_experts is not None assert self.shared_experts is not None
......
...@@ -127,7 +127,6 @@ class BaseRouter(FusedMoERouter): ...@@ -127,7 +127,6 @@ class BaseRouter(FusedMoERouter):
self.eplb_state = eplb_state self.eplb_state = eplb_state
self.enable_eplb = enable_eplb self.enable_eplb = enable_eplb
self.indices_type_getter = indices_type_getter self.indices_type_getter = indices_type_getter
self.capture: Callable[[torch.tensor], None] | None = None
def _validate_eplb_state(self) -> None: def _validate_eplb_state(self) -> None:
"""Validate that EPLB state is properly initialized if EPLB is enabled.""" """Validate that EPLB state is properly initialized if EPLB is enabled."""
...@@ -238,8 +237,4 @@ class BaseRouter(FusedMoERouter): ...@@ -238,8 +237,4 @@ class BaseRouter(FusedMoERouter):
# Step 5: Convert indices dtype # Step 5: Convert indices dtype
topk_ids = self._convert_indices_dtype(topk_ids, indices_type) topk_ids = self._convert_indices_dtype(topk_ids, indices_type)
# TODO(bnell): temporary hack until select_experts is moved into FusedMoE
if self.capture is not None:
self.capture(topk_ids)
return topk_weights, topk_ids return topk_weights, topk_ids
...@@ -55,4 +55,6 @@ class CustomRoutingRouter(BaseRouter): ...@@ -55,4 +55,6 @@ class CustomRoutingRouter(BaseRouter):
renormalize=self.renormalize, renormalize=self.renormalize,
) )
return topk_weights.to(torch.float32), topk_ids return topk_weights.to(torch.float32), topk_ids.to(
torch.int32 if indices_type is None else indices_type
)
...@@ -124,7 +124,9 @@ def fused_topk_bias( ...@@ -124,7 +124,9 @@ def fused_topk_bias(
topk_weights = scores.gather(1, topk_indices) topk_weights = scores.gather(1, topk_indices)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32) return topk_weights.to(torch.float32), topk_indices.to(
torch.int32 if indices_type is None else indices_type
)
class FusedTopKBiasRouter(BaseRouter): class FusedTopKBiasRouter(BaseRouter):
...@@ -176,6 +178,7 @@ class FusedTopKBiasRouter(BaseRouter): ...@@ -176,6 +178,7 @@ class FusedTopKBiasRouter(BaseRouter):
topk=self.top_k, topk=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
indices_type=indices_type,
) )
if self.routed_scaling_factor != 1.0: if self.routed_scaling_factor != 1.0:
......
...@@ -6,7 +6,6 @@ import torch ...@@ -6,7 +6,6 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
CustomRoutingRouter, CustomRoutingRouter,
) )
...@@ -49,7 +48,6 @@ def create_fused_moe_router( ...@@ -49,7 +48,6 @@ def create_fused_moe_router(
# eplb parameters # eplb parameters
enable_eplb: bool = False, enable_eplb: bool = False,
eplb_state: EplbLayerState = EMPTY_EPLB_STATE, eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
capture: Callable[[torch.tensor], None] | None = None,
) -> FusedMoERouter: ) -> FusedMoERouter:
""" """
Factory function to create the appropriate FusedMoERouter subclass based on Factory function to create the appropriate FusedMoERouter subclass based on
...@@ -90,21 +88,16 @@ def create_fused_moe_router( ...@@ -90,21 +88,16 @@ def create_fused_moe_router(
Returns: Returns:
An instance of the appropriate FusedMoERouter subclass An instance of the appropriate FusedMoERouter subclass
""" """
router: BaseRouter
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "": if routing_strategy != "":
router = RoutingSimulatorRouter( return RoutingSimulatorRouter(
top_k=top_k, top_k=top_k,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
eplb_state=eplb_state, eplb_state=eplb_state,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
) )
# TODO(bnell): this is temporary until select_experts is
# separated from apply.
router.capture = capture
return router
if use_grouped_topk: if use_grouped_topk:
assert custom_routing_function is None assert custom_routing_function is None
...@@ -113,7 +106,7 @@ def create_fused_moe_router( ...@@ -113,7 +106,7 @@ def create_fused_moe_router(
"num_expert_group and topk_group must be provided when " "num_expert_group and topk_group must be provided when "
"use_grouped_topk is True" "use_grouped_topk is True"
) )
router = GroupedTopKRouter( return GroupedTopKRouter(
top_k=top_k, top_k=top_k,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
eplb_state=eplb_state, eplb_state=eplb_state,
...@@ -127,11 +120,9 @@ def create_fused_moe_router( ...@@ -127,11 +120,9 @@ def create_fused_moe_router(
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
) )
router.capture = capture
return router
if custom_routing_function is not None: if custom_routing_function is not None:
router = CustomRoutingRouter( return CustomRoutingRouter(
top_k=top_k, top_k=top_k,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
eplb_state=eplb_state, eplb_state=eplb_state,
...@@ -140,11 +131,9 @@ def create_fused_moe_router( ...@@ -140,11 +131,9 @@ def create_fused_moe_router(
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
) )
router.capture = capture
return router
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
router = FusedTopKBiasRouter( return FusedTopKBiasRouter(
top_k=top_k, top_k=top_k,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
eplb_state=eplb_state, eplb_state=eplb_state,
...@@ -155,10 +144,8 @@ def create_fused_moe_router( ...@@ -155,10 +144,8 @@ def create_fused_moe_router(
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
) )
router.capture = capture
return router
router = FusedTopKRouter( return FusedTopKRouter(
top_k=top_k, top_k=top_k,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
eplb_state=eplb_state, eplb_state=eplb_state,
...@@ -167,5 +154,3 @@ def create_fused_moe_router( ...@@ -167,5 +154,3 @@ def create_fused_moe_router(
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
) )
router.capture = capture
return router
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -31,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import ( ...@@ -31,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel, make_unquantized_moe_kernel,
select_unquantized_moe_backend, select_unquantized_moe_backend,
) )
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
...@@ -66,6 +64,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -66,6 +64,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
) )
self.kernel: mk.FusedMoEModularKernel | None = None self.kernel: mk.FusedMoEModularKernel | None = None
self._is_monolithic = current_platform.is_cpu() or current_platform.is_xpu()
@property
def is_monolithic(self) -> bool:
return self._is_monolithic
@property @property
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
...@@ -212,7 +215,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -212,7 +215,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( self.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
use_prepack=True, use_prepack=True,
...@@ -244,11 +247,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -244,11 +247,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) )
assert packed_w2_weight.size() == layer.w2_weight.size() assert packed_w2_weight.size() == layer.w2_weight.size()
layer.w2_weight.copy_(packed_w2_weight) layer.w2_weight.copy_(packed_w2_weight)
layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) self.cpu_fused_moe: Callable = cpu_fused_moe.SGLFusedMOE(layer)
else: else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else: else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_cuda_alike(): elif current_platform.is_cuda_alike():
self._setup_kernel( self._setup_kernel(
layer=layer, layer=layer,
...@@ -259,15 +262,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -259,15 +262,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def apply( def apply(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward( return self.forward(
router=router,
layer=layer, layer=layer,
x=x, x=x,
router_logits=router_logits, topk_weights=topk_weights,
topk_ids=topk_ids,
) )
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
...@@ -282,18 +285,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -282,18 +285,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda( def forward_cuda(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel assert self.kernel is not None
return self.kernel(
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
result = self.kernel(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -306,24 +303,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -306,24 +303,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=layer.expert_map, expert_map=layer.expert_map,
) )
return result def forward_monolithic_cpu(
def forward_cpu(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( return self.cpu_fused_moe(
layer.enable_eplb is not False
or layer.eplb_state.expert_load_view is not None
or layer.eplb_state.logical_to_physical_map is not None
or layer.eplb_state.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for CPU.")
return layer.cpu_fused_moe(
layer, layer,
x, x,
layer.use_grouped_topk, layer.use_grouped_topk,
...@@ -342,21 +328,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -342,21 +328,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.activation, layer.activation,
) )
def forward_xpu( def forward_monolithic_xpu(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( return self.ipex_fusion(
layer.enable_eplb is not False
or layer.eplb_state.expert_load_view is not None
or layer.eplb_state.logical_to_physical_map is not None
or layer.eplb_state.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion(
x, x,
layer.use_grouped_topk, layer.use_grouped_topk,
layer.top_k, layer.top_k,
...@@ -368,8 +346,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -368,8 +346,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) )
if current_platform.is_cpu(): if current_platform.is_cpu():
forward_native = forward_cpu forward_native: Callable = forward_monolithic_cpu
apply_monolithic = forward_monolithic_cpu
elif current_platform.is_xpu(): elif current_platform.is_xpu():
forward_native = forward_xpu forward_native = forward_monolithic_xpu
apply_monolithic = forward_monolithic_xpu
else: else:
forward_native = forward_cuda forward_native = forward_cuda
...@@ -10,7 +10,6 @@ from torch.nn import Parameter ...@@ -10,7 +10,6 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -762,15 +761,10 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -762,15 +761,10 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_qweight, layer.w13_qweight,
...@@ -779,7 +773,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -779,7 +773,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
getattr(layer, "w2_bias", None), getattr(layer, "w2_bias", None),
layer.w13_scales, layer.w13_scales,
layer.w2_scales, layer.w2_scales,
router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
......
...@@ -6,7 +6,6 @@ from typing import Any, Union ...@@ -6,7 +6,6 @@ from typing import Any, Union
import torch import torch
from packaging import version from packaging import version
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -499,16 +498,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -499,16 +498,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
# TODO(bnell): Do these need to be called on the hot path? # TODO(bnell): Do these need to be called on the hot path?
if self.quant_config.load_in_8bit: if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer) w13, w2 = self._apply_8bit_dequant(layer)
......
...@@ -21,7 +21,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -21,7 +21,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEActivationFormat, FusedMoEActivationFormat,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoERouter,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod, UnquantizedFusedMoEMethod,
) )
...@@ -126,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -126,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module, layer: torch.nn.Module,
layer_name: str, layer_name: str,
) -> "CompressedTensorsMoEMethod": ) -> FusedMoEMethodBase:
# FusedMoE was made by combining multiple Linears so need to # FusedMoE was made by combining multiple Linears so need to
# make sure quantization config for Linear can target it # make sure quantization config for Linear can target it
quant_config._add_fused_moe_to_target_scheme_map() quant_config._add_fused_moe_to_target_scheme_map()
...@@ -345,19 +344,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -345,19 +344,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids = router.select_experts(
hidden_states=x_routing,
router_logits=router_logits,
)
assert self.kernel is not None assert self.kernel is not None
return self.kernel( return self.kernel(
x, x,
...@@ -639,19 +629,25 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -639,19 +629,25 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
) )
def apply( @property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
assert (
if (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb and not layer.enable_eplb
): )
return flashinfer_trtllm_fp4_moe( return flashinfer_trtllm_fp4_moe(
layer=layer, layer=layer,
x=x, x=x,
...@@ -665,15 +661,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -665,15 +661,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=layer.e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
) )
# Hidden_states in select_experts is only used to extract metadata def apply(
if isinstance(x, tuple): self,
x_routing, _ = x layer: FusedMoE,
else: x: torch.Tensor,
x_routing = x topk_weights: torch.Tensor,
topk_weights, topk_ids = router.select_experts( topk_ids: torch.Tensor,
hidden_states=x_routing, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
router_logits=router_logits, assert not self.is_monolithic
) assert layer.activation == "silu", "Only SiLU activation is supported."
# EPLB path # EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
...@@ -1059,18 +1055,18 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1059,18 +1055,18 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
) )
def apply( @property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: assert self.is_monolithic
if layer.enable_eplb: assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
raise NotImplementedError(
"EPLB not supported for `FlashInfer TRTLLM FP8 MoE`."
)
assert layer.activation == "silu" assert layer.activation == "silu"
if self.block_quant: if self.block_quant:
...@@ -1116,13 +1112,16 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1116,13 +1112,16 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
topk_weights, topk_ids = router.select_experts( def apply(
hidden_states=x, self,
router_logits=router_logits, layer: FusedMoE,
) x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.kernel is not None assert self.kernel is not None
result = self.kernel( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1137,8 +1136,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1137,8 +1136,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
return result
@property @property
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return True return True
...@@ -1257,17 +1254,12 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1257,17 +1254,12 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
...@@ -1621,15 +1613,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1621,15 +1613,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_weight_packed, layer.w13_weight_packed,
...@@ -1638,7 +1625,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1638,7 +1625,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
None, None,
layer.w13_weight_scale, layer.w13_weight_scale,
layer.w2_weight_scale, layer.w2_weight_scale,
router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
...@@ -1873,17 +1859,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1873,17 +1859,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_experts( return fused_experts(
x, x,
layer.w13_weight_packed, layer.w13_weight_packed,
...@@ -2172,10 +2153,13 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2172,10 +2153,13 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
# fused_experts; quant config is not needed. # fused_experts; quant config is not needed.
return None return None
def apply( @property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -2489,19 +2473,15 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2489,19 +2473,15 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
): topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
) )
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_w4a8_fp8, cutlass_moe_w4a8_fp8,
......
...@@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEConfig, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoERouter,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -138,17 +137,12 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -138,17 +137,12 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_experts( return fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
......
...@@ -22,7 +22,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -22,7 +22,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
FusedMoERouter,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -968,14 +967,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -968,14 +967,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def allow_inplace(self) -> bool: def allow_inplace(self) -> bool:
return True return True
def apply( @property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
# TODO(rob): convert this to MK. # TODO(rob): convert this to MK.
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
...@@ -1026,13 +1030,16 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1026,13 +1030,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
topk_weights, topk_ids = router.select_experts( def apply(
hidden_states=x, self,
router_logits=router_logits, layer: FusedMoE,
) x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.kernel is not None
result = self.kernel( assert not self.is_monolithic
return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1045,8 +1052,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1045,8 +1052,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
return result
class Fp8OnlineMoEMethod(Fp8MoEMethod): class Fp8OnlineMoEMethod(Fp8MoEMethod):
"""MoE method for online FP8 quantization. """MoE method for online FP8 quantization.
......
...@@ -12,7 +12,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter ...@@ -12,7 +12,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -633,9 +632,9 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -633,9 +632,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
if layer.apply_router_weight_on_input: if layer.apply_router_weight_on_input:
...@@ -644,10 +643,6 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -644,10 +643,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method." "fused GGUF MoE method."
) )
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_moe_gguf( return fused_moe_gguf(
x, x,
layer.w13_qweight, layer.w13_qweight,
......
...@@ -10,7 +10,6 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE ...@@ -10,7 +10,6 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -898,15 +897,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -898,15 +897,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_qweight, layer.w13_qweight,
...@@ -915,7 +909,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -915,7 +909,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
getattr(layer, "w2_bias", None), getattr(layer, "w2_bias", None),
layer.w13_scales, layer.w13_scales,
layer.w2_scales, layer.w2_scales,
router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
......
...@@ -8,9 +8,6 @@ from packaging import version ...@@ -8,9 +8,6 @@ from packaging import version
from torch.nn import Module from torch.nn import Module
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
...@@ -384,10 +381,13 @@ class XPUFp8MoEMethod(Fp8OnlineMoEMethod): ...@@ -384,10 +381,13 @@ class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
return None return None
def apply( @property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -13,7 +13,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -13,7 +13,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -945,14 +944,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -945,14 +944,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
a2_scale=a2_scale, a2_scale=a2_scale,
) )
def apply( @property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend." "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
...@@ -975,11 +978,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -975,11 +978,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
# Expert selection def apply(
topk_weights, topk_ids = router.select_experts( self,
hidden_states=x, layer: FusedMoE,
router_logits=router_logits, x: torch.Tensor,
) topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# TODO(rob): this validation should happen at kernel selection # TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here. # time in the oracle rather than here.
...@@ -990,7 +996,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -990,7 +996,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
assert self.kernel is not None assert self.kernel is not None
result = self.kernel( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1003,8 +1009,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -1003,8 +1009,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
return result
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
...@@ -1629,17 +1633,25 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1629,17 +1633,25 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return True return True
def apply( @property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( assert self.is_monolithic
assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb and not layer.enable_eplb
): )
return flashinfer_trtllm_fp4_moe( return flashinfer_trtllm_fp4_moe(
layer=layer, layer=layer,
x=x, x=x,
...@@ -1653,15 +1665,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1653,15 +1665,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=layer.e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
) )
# Hidden_states in select_experts is only used to extract metadata def apply(
if isinstance(x, tuple): self,
x_routing, _ = x layer: FusedMoE,
else: x: torch.Tensor,
x_routing = x topk_weights: torch.Tensor,
topk_weights, topk_ids = router.select_experts( topk_ids: torch.Tensor,
hidden_states=x_routing, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
router_logits=router_logits, assert not self.is_monolithic
)
# EPLB path # EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
......
...@@ -6,7 +6,6 @@ from typing import Any, Optional ...@@ -6,7 +6,6 @@ from typing import Any, Optional
import torch import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
int4_w4a16_moe_quant_config, int4_w4a16_moe_quant_config,
...@@ -365,17 +364,13 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -365,17 +364,13 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_experts( return fused_experts(
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment