Unverified Commit d1481ba7 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Introduce MoERunner abstraction and move execution logic from...


[MoE Refactor] Introduce MoERunner abstraction and move execution logic from FusedMoE to DefaultMoERunner (#32344)
Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent dc6de33c
......@@ -958,6 +958,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
......@@ -980,7 +981,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
......@@ -1524,6 +1525,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
......@@ -1551,7 +1553,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
......
......@@ -367,6 +367,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......
......@@ -900,6 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
if layer.enable_eplb:
......
......@@ -419,6 +419,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
......@@ -607,6 +608,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
......@@ -977,6 +979,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
if (
......
......@@ -816,10 +816,14 @@ class Worker(WorkerBase):
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size,
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment