Unverified Commit d1481ba7 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Introduce MoERunner abstraction and move execution logic from...


[MoE Refactor] Introduce MoERunner abstraction and move execution logic from FusedMoE to DefaultMoERunner (#32344)
Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent dc6de33c
...@@ -958,6 +958,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -958,6 +958,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
...@@ -980,7 +981,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -980,7 +981,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x), shared_experts_input=shared_experts_input,
) )
...@@ -1524,6 +1525,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1524,6 +1525,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
...@@ -1551,7 +1553,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1551,7 +1553,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x), shared_experts_input=shared_experts_input,
) )
......
...@@ -367,6 +367,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -367,6 +367,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
......
...@@ -900,6 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -900,6 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
if layer.enable_eplb: if layer.enable_eplb:
......
...@@ -419,6 +419,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -419,6 +419,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
...@@ -607,6 +608,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -607,6 +608,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts, rocm_aiter_fused_experts,
...@@ -977,6 +979,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -977,6 +979,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate: if not self.emulate:
if ( if (
......
...@@ -816,10 +816,14 @@ class Worker(WorkerBase): ...@@ -816,10 +816,14 @@ class Worker(WorkerBase):
for module in moe_modules: for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make( module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size, tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size, pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size, dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config, vllm_parallel_config=parallel_config,
) )
module.moe_config.moe_parallel_config = module.moe_parallel_config module.moe_config.moe_parallel_config = module.moe_parallel_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment