Unverified commit 8f066146, authored by bnellnm, committed by GitHub
Browse files

[MoE][Refactor] Make select_experts a non-static method (#29067)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent cec418b5
...@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8, flashinfer_cutlass_moe_fp8,
...@@ -151,14 +150,11 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -151,14 +150,11 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states, hidden_states=td.hidden_states,
router_logits=score, gating_output=score,
use_grouped_topk=False, topk=topk,
top_k=topk,
renormalize=False, renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax",
) )
quant_config = fp8_w8a8_moe_quant_config( quant_config = fp8_w8a8_moe_quant_config(
...@@ -219,14 +215,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -219,14 +215,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
) )
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states, hidden_states=td.hidden_states,
router_logits=score, gating_output=score,
use_grouped_topk=False, topk=topk,
top_k=topk,
renormalize=False, renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax",
) )
quant_config = fp8_w8a8_moe_quant_config( quant_config = fp8_w8a8_moe_quant_config(
......
...@@ -9,9 +9,16 @@ different routing strategies and analyze their performance, including ...@@ -9,9 +9,16 @@ different routing strategies and analyze their performance, including
integration tests with FusedMoE layer. integration tests with FusedMoE layer.
""" """
import tempfile
import pytest import pytest
import torch import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import ( from vllm.model_executor.layers.fused_moe.routing_simulator import (
DistributionBasedRouting, DistributionBasedRouting,
RoutingSimulator, RoutingSimulator,
...@@ -89,6 +96,28 @@ def test_routing_strategy_integration(monkeypatch, device): ...@@ -89,6 +96,28 @@ def test_routing_strategy_integration(monkeypatch, device):
# Test different routing strategies # Test different routing strategies
strategies = RoutingSimulator.get_available_strategies() strategies = RoutingSimulator.get_available_strategies()
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
fused_moe = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=0,
use_grouped_topk=False,
renormalize=True,
)
for strategy in strategies: for strategy in strategies:
# Set environment variable # Set environment variable
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY" env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
...@@ -98,13 +127,9 @@ def test_routing_strategy_integration(monkeypatch, device): ...@@ -98,13 +127,9 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method # Test the select_experts method
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = fused_moe.select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k,
use_grouped_topk=False,
renormalize=True,
indices_type=torch.long,
) )
# Verify output shapes # Verify output shapes
......
...@@ -90,10 +90,14 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -90,10 +90,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def allow_inplace(self) -> bool: def allow_inplace(self) -> bool:
return False return False
@property
def method_name(self) -> str:
return self.__class__.__name__
@abstractmethod @abstractmethod
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
......
...@@ -66,6 +66,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -66,6 +66,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def allow_inplace(self) -> bool: def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace return self.old_quant_method.allow_inplace
@property
def method_name(self) -> str:
return self.old_quant_method.method_name
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -84,7 +88,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -84,7 +88,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -105,42 +109,9 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -105,42 +109,9 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Is getattr needed?
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
if enable_eplb:
if self.supports_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
else:
raise NotImplementedError(
"EPLB is not supported for "
f"{self.old_quant_method.__class__.__name__}."
)
topk_weights, topk_ids, zero_expert_result = layer.select_experts( topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
) )
result = self.fused_experts( result = self.fused_experts(
...@@ -156,7 +127,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -156,7 +127,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else expert_map,
) )
if zero_expert_num != 0 and zero_expert_type is not None: if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), ( assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported" "Shared + zero experts are mutually exclusive not yet supported"
) )
......
...@@ -1510,30 +1510,11 @@ class FusedMoE(CustomOp): ...@@ -1510,30 +1510,11 @@ class FusedMoE(CustomOp):
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device() logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
) )
@staticmethod
def select_experts( def select_experts(
self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
use_grouped_topk: bool,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
indices_type: torch.dtype | None = None,
enable_eplb: bool = False,
expert_map: torch.Tensor | None = None,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
global_num_experts: int | None = None,
zero_expert_num: int | None = None,
zero_expert_type: str | None = None,
num_fused_shared_experts: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Route the input hidden states to the top-k experts based on the Route the input hidden states to the top-k experts based on the
router logits. router logits.
...@@ -1552,6 +1533,27 @@ class FusedMoE(CustomOp): ...@@ -1552,6 +1533,27 @@ class FusedMoE(CustomOp):
fused_topk_bias, fused_topk_bias,
) )
if self.enable_eplb:
if self.quant_method.supports_eplb:
if self.expert_load_view is None:
raise ValueError(
"enable_eplb=True requiere expert_load_view != None"
)
if self.logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requiere logical_to_physical_map != None"
)
if self.logical_replica_count is None:
raise ValueError(
"enable_eplb=True requiere logical_replica_count != None"
)
else:
raise NotImplementedError(
f"EPLB is not supported for {self.quant_method.method_name}."
)
indices_type = self.quant_method.topk_indices_dtype
# Check if we should use a routing simulation strategy # Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "": if routing_strategy != "":
...@@ -1559,20 +1561,20 @@ class FusedMoE(CustomOp): ...@@ -1559,20 +1561,20 @@ class FusedMoE(CustomOp):
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
strategy_name=routing_strategy, strategy_name=routing_strategy,
top_k=top_k, top_k=self.top_k,
indices_type=indices_type, indices_type=indices_type,
) )
# DeepSeekv2 uses grouped_top_k # DeepSeekv2 uses grouped_top_k
elif use_grouped_topk: elif self.use_grouped_topk:
assert topk_group is not None assert self.topk_group is not None
assert num_expert_group is not None assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled(): if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert num_fused_shared_experts == 0 assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial( grouped_topk_impl = partial(
rocm_aiter_grouped_topk, rocm_aiter_grouped_topk,
num_fused_shared_experts=num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
) )
else: else:
grouped_topk_impl = grouped_topk grouped_topk_impl = grouped_topk
...@@ -1580,50 +1582,46 @@ class FusedMoE(CustomOp): ...@@ -1580,50 +1582,46 @@ class FusedMoE(CustomOp):
topk_weights, topk_ids = grouped_topk_impl( topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=self.top_k,
renormalize=renormalize, renormalize=self.renormalize,
num_expert_group=num_expert_group, num_expert_group=self.num_expert_group,
topk_group=topk_group, topk_group=self.topk_group,
scoring_func=scoring_func, scoring_func=self.scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
) )
elif e_score_correction_bias is not None: elif self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias( topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
e_score_correction_bias=e_score_correction_bias.data, e_score_correction_bias=self.e_score_correction_bias.data,
topk=top_k, topk=self.top_k,
renormalize=renormalize, renormalize=self.renormalize,
) )
if routed_scaling_factor != 1.0: if self.routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor topk_weights *= self.routed_scaling_factor
elif custom_routing_function is None: elif self.custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=self.top_k,
renormalize=renormalize, renormalize=self.renormalize,
indices_type=indices_type, indices_type=indices_type,
) )
else: else:
topk_weights, topk_ids = custom_routing_function( topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=self.top_k,
renormalize=renormalize, renormalize=self.renormalize,
) )
if enable_eplb: if self.enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
topk_ids = eplb_map_to_physical_and_record( topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids, topk_ids=topk_ids,
expert_load_view=expert_load_view, expert_load_view=self.expert_load_view,
logical_to_physical_map=logical_to_physical_map, logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=logical_replica_count, logical_replica_count=self.logical_replica_count,
) )
if (indices_type is not None) and topk_ids.dtype != indices_type: if (indices_type is not None) and topk_ids.dtype != indices_type:
...@@ -1633,16 +1631,16 @@ class FusedMoE(CustomOp): ...@@ -1633,16 +1631,16 @@ class FusedMoE(CustomOp):
# Compute zero expert result if needed # Compute zero expert result if needed
if ( if (
zero_expert_num is not None self.zero_expert_num is not None
and zero_expert_num > 0 and self.zero_expert_num > 0
and zero_expert_type is not None and self.zero_expert_type is not None
and global_num_experts is not None and self.global_num_experts is not None
): ):
zero_expert_result = zero_experts_compute_triton( zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids, expert_indices=topk_ids,
expert_scales=topk_weights, expert_scales=topk_weights,
num_experts=global_num_experts, num_experts=self.global_num_experts,
zero_expert_type=zero_expert_type, zero_expert_type=self.zero_expert_type,
hidden_states=hidden_states, hidden_states=hidden_states,
) )
else: else:
......
...@@ -331,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -331,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
...@@ -352,31 +352,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -352,31 +352,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
topk_weights, topk_ids, zero_expert_result = layer.select_experts( topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
...@@ -415,7 +393,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -415,7 +393,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map, expert_map=expert_map,
) )
if zero_expert_num != 0 and zero_expert_type is not None: if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), ( assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported" "Shared + zero experts are mutually exclusive not yet supported"
) )
...@@ -425,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -425,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cpu( def forward_cpu(
self, self,
layer: torch.nn.Module, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
...@@ -474,7 +452,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -474,7 +452,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_xpu( def forward_xpu(
self, self,
layer: torch.nn.Module, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
...@@ -515,7 +493,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -515,7 +493,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_tpu( def forward_tpu(
self, self,
layer: torch.nn.Module, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
......
...@@ -597,7 +597,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -597,7 +597,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -618,24 +618,11 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -618,24 +618,11 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
return fused_marlin_moe( return fused_marlin_moe(
......
...@@ -495,7 +495,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -495,7 +495,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -518,25 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -518,25 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb: topk_weights, topk_ids, _ = layer.select_experts(
raise NotImplementedError(
"EPLB not supported for `BitsAndBytesMoEMethod` yet."
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
# TODO(bnell): Do these need to be called on the hot path?
if self.quant_config.load_in_8bit: if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer) w13, w2 = self._apply_8bit_dequant(layer)
else: else:
......
...@@ -511,7 +511,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -511,7 +511,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -532,16 +532,17 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -532,16 +532,17 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
)
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if ( if (
self.allow_flashinfer self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
)
return flashinfer_trtllm_fp4_moe( return flashinfer_trtllm_fp4_moe(
layer=layer, layer=layer,
x=x, x=x,
...@@ -554,19 +555,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -554,19 +555,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
) )
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
if self.use_marlin: if self.use_marlin:
...@@ -1109,7 +1100,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1109,7 +1100,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -1130,31 +1121,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1130,31 +1121,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: topk_weights, topk_ids, _ = layer.select_experts(
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
num_fused_shared_experts=layer.num_fused_shared_experts,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
) )
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
...@@ -1377,7 +1346,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1377,7 +1346,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -1398,26 +1367,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1398,26 +1367,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."
)
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
return fused_experts( return fused_experts(
...@@ -1738,7 +1692,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1738,7 +1692,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -1759,26 +1713,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1759,26 +1713,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."
)
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert activation == "silu", f"{activation} not supported for Marlin MoE."
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
return fused_marlin_moe( return fused_marlin_moe(
...@@ -2001,7 +1940,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -2001,7 +1940,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -2022,43 +1961,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -2022,43 +1961,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
if expert_load_view is None:
raise ValueError("enable_eplb=True requiere expert_load_view != None")
if logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requiere logical_to_physical_map != None"
)
if logical_replica_count is None:
raise ValueError(
"enable_eplb=True requiere logical_replica_count != None"
)
if not isinstance(layer, FusedMoE):
raise TypeError(
"EPLB is only supported when `layer` is a instance of FusedMoE."
)
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
) )
return fused_experts( return fused_experts(
......
...@@ -137,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -137,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -158,26 +158,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -158,26 +158,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ExpertsInt8MoEMethod` yet."
)
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
return fused_experts( return fused_experts(
......
...@@ -1140,7 +1140,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1140,7 +1140,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -1216,31 +1216,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1216,31 +1216,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
zero_expert_num = getattr(layer, "zero_expert_num", 0) select_result = layer.select_experts(
zero_expert_type = getattr(layer, "zero_expert_type", None)
select_result = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
) )
topk_weights, topk_ids, zero_expert_result = select_result topk_weights, topk_ids, zero_expert_result = select_result
...@@ -1322,7 +1300,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1322,7 +1300,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.allow_cutlass_block_scaled_grouped_gemm self.allow_cutlass_block_scaled_grouped_gemm
), ),
) )
if zero_expert_num != 0 and zero_expert_type is not None:
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), ( assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported" "Shared + zero experts are mutually exclusive not yet supported"
) )
......
...@@ -621,7 +621,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -621,7 +621,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -642,9 +642,6 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -642,9 +642,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input: if apply_router_weight_on_input:
raise NotImplementedError( raise NotImplementedError(
...@@ -652,19 +649,9 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -652,19 +649,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method." "fused GGUF MoE method."
) )
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
return fused_moe_gguf( return fused_moe_gguf(
x, x,
......
...@@ -722,7 +722,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -722,7 +722,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -743,26 +743,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -743,26 +743,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet."
)
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
return fused_marlin_moe( return fused_marlin_moe(
......
...@@ -696,7 +696,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -696,7 +696,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -717,12 +717,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -717,12 +717,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet." "EPLB not supported for `ModelOptFp8MoEMethod` yet."
) )
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert activation == "silu", ( assert activation == "silu", (
f"Expected 'silu' activation but got {activation}" f"Expected 'silu' activation but got {activation}"
) )
...@@ -740,19 +739,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -740,19 +739,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
# Expert selection # Expert selection
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
...@@ -1459,7 +1448,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1459,7 +1448,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -1480,16 +1469,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1480,16 +1469,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if ( if (
self.allow_flashinfer self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
return flashinfer_trtllm_fp4_moe( return flashinfer_trtllm_fp4_moe(
layer=layer, layer=layer,
x=x, x=x,
...@@ -1502,19 +1491,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1502,19 +1491,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
) )
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
if self.use_marlin: if self.use_marlin:
......
...@@ -359,7 +359,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -359,7 +359,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -380,25 +380,12 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -380,25 +380,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
return fused_experts( return fused_experts(
......
...@@ -862,7 +862,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -862,7 +862,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -887,18 +887,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -887,18 +887,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
) )
return fused_marlin_moe( return fused_marlin_moe(
...@@ -989,17 +980,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -989,17 +980,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
): ):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
) )
# Backend-specific preparation # Backend-specific preparation
......
...@@ -334,7 +334,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -334,7 +334,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -355,24 +355,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -355,24 +355,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: topk_weights, topk_ids, _ = layer.select_experts(
raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
...@@ -609,7 +594,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -609,7 +594,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -630,24 +615,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -630,24 +615,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: topk_weights, topk_ids, _ = layer.select_experts(
raise NotImplementedError(
"EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
if not self.emulate: if not self.emulate:
......
...@@ -356,7 +356,7 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -356,7 +356,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -377,22 +377,9 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -377,22 +377,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: topk_weights, topk_ids, _ = layer.select_experts(
raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
return fused_marlin_moe( return fused_marlin_moe(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment