Unverified Commit d1481ba7, authored by bnellnm, committed by GitHub
Browse files

[MoE Refactor] Introduce MoERunner abstraction and move execution logic from...


[MoE Refactor] Introduce MoERunner abstraction and move execution logic from FusedMoE to DefaultMoERunner (#32344)
Signed-off-by: Bill Nell <bnell@redhat.com>
parent dc6de33c
......@@ -32,7 +32,7 @@ th {
| Backend | Output act. format | Quant. types | Quant. format | Async | Apply Weight On Input | Subclass |
|---------|--------------------|--------------|---------------|-------|-----------------------|-----------|
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE.forward_impl] |
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] |
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
......
......@@ -585,6 +585,7 @@ def make_modular_kernel(
tp_size_=get_tensor_model_parallel_world_size(),
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=1,
vllm_parallel_config=vllm_config.parallel_config,
)
......@@ -594,6 +595,7 @@ def make_modular_kernel(
hidden_dim=config.K,
intermediate_size_per_partition=config.N,
num_local_experts=config.num_local_experts,
num_logical_experts=config.E,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),
......
......@@ -52,6 +52,7 @@ def make_dummy_moe_config(
hidden_dim=hidden_dim,
intermediate_size_per_partition=intermediate_size_per_partition,
num_local_experts=num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation="silu",
in_dtype=in_dtype,
......
......@@ -913,12 +913,16 @@ class FusedMoEParallelConfig:
pcp_rank: int
dp_rank: int
ep_rank: int
sp_size: int
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
is_sequence_parallel: bool # whether sequence parallelism is used
enable_eplb: bool # whether to enable expert load balancing
@property
def is_sequence_parallel(self) -> bool:
    """Whether sequence parallelism is in use, i.e. ``sp_size`` exceeds 1."""
    sequence_parallel_size = self.sp_size
    return sequence_parallel_size > 1
@property
def use_all2all_kernels(self):
    """Truthy only when ``dp_size`` exceeds 1 and ``use_ep`` is set."""
    has_multiple_dp_ranks = self.dp_size > 1
    return has_multiple_dp_ranks and self.use_ep
......@@ -974,6 +978,7 @@ class FusedMoEParallelConfig:
tp_size_: int,
pcp_size_: int,
dp_size_: int,
sp_size_: int,
vllm_parallel_config: ParallelConfig,
) -> "FusedMoEParallelConfig":
"""
......@@ -1073,9 +1078,9 @@ class FusedMoEParallelConfig:
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
sp_size=sp_size_,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
# DP + EP / TP + EP / DP + TP + EP
......@@ -1093,9 +1098,9 @@ class FusedMoEParallelConfig:
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
sp_size=sp_size_,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
......@@ -1111,10 +1116,10 @@ class FusedMoEParallelConfig:
dp_rank=0,
ep_size=1,
ep_rank=0,
sp_size=1,
use_ep=False,
all2all_backend="naive",
enable_eplb=False,
is_sequence_parallel=False,
)
......@@ -1126,6 +1131,7 @@ class FusedMoEConfig:
hidden_dim: int
intermediate_size_per_partition: int
num_local_experts: int
num_logical_experts: int
activation: str
device: torch.device | str
routing_method: RoutingMethodType
......@@ -1175,6 +1181,14 @@ class FusedMoEConfig:
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def sp_size(self):
    """Return ``sp_size`` as stored on ``self.moe_parallel_config``."""
    return self.moe_parallel_config.sp_size
@property
def is_sequence_parallel(self):
    """Return ``is_sequence_parallel`` from ``self.moe_parallel_config``."""
    return self.moe_parallel_config.is_sequence_parallel
@property
def tp_rank(self):
    """Return ``tp_rank`` as stored on ``self.moe_parallel_config``."""
    return self.moe_parallel_config.tp_rank
......
......@@ -121,17 +121,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def is_monolithic(self) -> bool:
return False
# @abstractmethod
def apply(
    self,
    layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Base-class placeholder for the non-monolithic MoE apply step.

    This default implementation always raises; subclasses are expected
    to override it with the same signature.

    Args:
        layer: The ``FusedMoE`` layer being executed.
        x: Input tensor for the experts.
        topk_weights: Routing weights for the selected experts.
        topk_ids: Indices of the selected experts.
        shared_experts_input: Optional tensor for shared experts, or
            ``None``. NOTE(review): exact semantics are defined by the
            caller and are not visible here.

    Returns:
        A tensor, or a pair of tensors.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
# @abstractmethod
def apply_monolithic(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
......
......@@ -89,6 +89,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
......@@ -101,5 +102,5 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
......@@ -1228,7 +1228,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None = None,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
class MoERunner(ABC):
    """
    Abstract base class for Mixture of Experts (MoE) runners.

    This class defines the interface that all MoE runner implementations must follow.
    MoE runners are responsible for executing the forward pass of MoE layers, handling
    expert routing, and managing tensor parallel operations.
    """

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass.

        Args:
            hidden_states: Input activations for the MoE layer.
            router_logits: Router outputs used for expert selection.

        Returns:
            A single tensor or a pair of tensors.
            NOTE(review): the pair presumably carries shared-expert output
            alongside routed output; confirm against implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def must_reduce_shared_expert_outputs(self) -> bool:
        """Report whether shared-expert outputs need to be reduced.

        NOTE(review): who performs the reduction is not shown in this
        interface; confirm against callers.
        """
        raise NotImplementedError

    @abstractmethod
    def maybe_all_reduce_tensor_model_parallel(
        self,
        final_hidden_states: torch.Tensor,
    ):
        """Optionally all-reduce ``final_hidden_states`` across tensor parallel ranks.

        NOTE(review): no return annotation is declared; presumably the
        (possibly reduced) tensor is returned. Confirm against implementations.
        """
        raise NotImplementedError
......@@ -18,70 +18,6 @@ class SharedFusedMoE(FusedMoE):
can be interleaved with the fused all2all dispatch communication step.
"""
def __init__(
    self,
    shared_experts: torch.nn.Module | None,
    gate: torch.nn.Module | None = None,
    use_overlapped: bool = True,
    routed_input_transform: torch.nn.Module | None = None,
    **kwargs,
):
    """Set up the layer and decide whether shared-expert overlap is enabled.

    Args:
        shared_experts: Optional shared-experts module, stored privately.
        gate: Optional gate module, stored privately.
        use_overlapped: Caller's request to overlap shared experts; it may
            be overridden to ``False`` by the checks below.
        routed_input_transform: Optional module applied to the input of
            the routed experts.
        **kwargs: Forwarded to the parent constructor, with
            ``has_shared_experts`` added.
    """
    # Pass has_shared_experts so FusedMoE.__init__ can set disable_inplace
    # without accessing self.shared_experts (submodules cannot be set before
    # Module.__init__()).
    kwargs["has_shared_experts"] = shared_experts is not None
    super().__init__(**kwargs)
    self._shared_experts = shared_experts
    self._routed_input_transform = routed_input_transform
    # Disable shared expert overlap if:
    # - we are using eplb with non-default backend, because of correctness issues
    # - we are using flashinfer with DP, since there is nothing to gain
    # - we are using marlin kernels
    # NOTE(review): no marlin check appears in the expression below; confirm
    # whether that bullet is stale or the check lives elsewhere.
    backend = self.moe_parallel_config.all2all_backend
    self.use_overlapped = (
        use_overlapped
        and not (
            (self.enable_eplb and backend != "allgather_reducescatter")
            or self.moe_parallel_config.use_fi_all2allv_kernels
        )
        and self._shared_experts is not None
    )
    # Stored last; the public ``gate`` property only exposes it when
    # ``use_overlapped`` is true.
    self._gate = gate
@property
def shared_experts(self) -> torch.nn.Module | None:
return self._shared_experts if self.use_overlapped else None
@property
def gate(self) -> torch.nn.Module | None:
return self._gate if self.use_overlapped else None
@property
def is_internal_router(self) -> bool:
    """True when the ``gate`` property currently yields a module."""
    exposed_gate = self.gate
    return exposed_gate is not None
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Run the optional routed-expert input transform (e.g., latent projection).

    This is called by FusedMoE.forward_native. The original hidden_states
    is saved separately so shared experts get [S, hidden_size] while
    routed experts get the transformed [S, moe_latent_size].

    When no transform is configured the input is returned unchanged.

    TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
    moved inside SharedFusedMoE to all-reduce on the smaller latent
    dimension.
    """
    transform = self._routed_input_transform
    if transform is None:
        return hidden_states

    transformed = transform(hidden_states)
    # ReplicatedLinear returns an (output, extra_bias) tuple; only the
    # output tensor is needed, extra_bias is dropped.
    return transformed[0] if isinstance(transformed, tuple) else transformed
def forward(
self,
hidden_states: torch.Tensor,
......
......@@ -55,6 +55,8 @@ logger = init_logger(__name__)
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
# --8<-- [end:unquantized_fused_moe]
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.unquantized_backend = select_unquantized_moe_backend(
......@@ -90,8 +92,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_cuda(layer, x, topk_weights, topk_ids)
return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
@property
def is_monolithic(self) -> bool:
......@@ -293,12 +296,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
layer=layer,
x=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_experts_input,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
......@@ -316,6 +321,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
......@@ -329,6 +335,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
)
def forward_monolithic_cuda(
......
......@@ -764,6 +764,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,
......
......@@ -501,6 +501,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......
......@@ -349,6 +349,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
......@@ -361,7 +362,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
......@@ -645,6 +646,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported."
......@@ -673,7 +675,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
......@@ -1064,6 +1066,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.moe_mk is not None
......@@ -1079,7 +1082,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
@property
......@@ -1203,6 +1206,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -1713,6 +1717,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel_backend == "Marlin"
return fused_marlin_moe(
......@@ -1961,6 +1966,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -2575,6 +2581,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(
......
......@@ -140,6 +140,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......
......@@ -1010,6 +1010,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert not self.is_monolithic
......@@ -1023,7 +1024,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
......
......@@ -635,6 +635,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
if layer.apply_router_weight_on_input:
......
......@@ -900,6 +900,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment