Unverified Commit 327a02d8 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Separate Router into OO Classes (#30623)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 2f03035a
...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEConfig, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoERouter,
) )
from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -27,7 +28,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( ...@@ -27,7 +28,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts, MarlinExperts,
fused_marlin_moe, fused_marlin_moe,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts, OAITritonExperts,
UnfusedOAITritonExperts, UnfusedOAITritonExperts,
...@@ -936,9 +936,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -936,9 +936,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.apply_router_weight_on_input, layer.apply_router_weight_on_input,
layer.scoring_func, layer.scoring_func,
layer.activation, layer.activation,
layer.expert_load_view, layer.eplb_state.expert_load_view,
layer.logical_to_physical_map, layer.eplb_state.logical_to_physical_map,
layer.logical_replica_count, layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration." ), "MXFP4 are not supported with this configuration."
if ( if (
......
...@@ -548,7 +548,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -548,7 +548,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -10,12 +10,12 @@ import torch ...@@ -10,12 +10,12 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoEMethodBase, FusedMoEMethodBase,
......
...@@ -201,6 +201,7 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -201,6 +201,7 @@ class Ernie4_5_MoeMoE(nn.Module):
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
...@@ -269,6 +269,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -269,6 +269,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
e_score_correction_bias=self.e_score_correction_bias[0], e_score_correction_bias=self.e_score_correction_bias[0],
prefix=f"{prefix}.text_experts", prefix=f"{prefix}.text_experts",
router_logits_dtype=torch.float32,
) )
else: else:
self.text_experts = Ernie4_5_VLMoeMLP( self.text_experts = Ernie4_5_VLMoeMLP(
...@@ -306,6 +307,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -306,6 +307,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
e_score_correction_bias=self.e_score_correction_bias[1], e_score_correction_bias=self.e_score_correction_bias[1],
prefix=f"{prefix}.vision_experts", prefix=f"{prefix}.vision_experts",
router_logits_dtype=torch.float32,
) )
else: else:
self.vision_experts = Ernie4_5_VLMoeMLP( self.vision_experts = Ernie4_5_VLMoeMLP(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment