Unverified commit 7cf56a59, authored by bnellnm and committed by GitHub
Browse files

[MoE Refactor] Make SharedExperts class for use with DefaultMoERunner (#35153)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent 5e30e9b9
......@@ -603,7 +603,6 @@ def make_shared_experts(
def modular_triton_fused_moe(
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
shared_experts: torch.nn.Module | None = None,
) -> FusedMoEKernel:
return FusedMoEKernel(
maybe_make_prepare_finalize(
......@@ -613,6 +612,5 @@ def modular_triton_fused_moe(
use_monolithic=False,
),
TritonExperts(moe_config, quant_config),
shared_experts,
inplace=False,
)
......@@ -103,6 +103,7 @@ pushd "$WORKSPACE"
echo "Downloading NVSHMEM ${NVSHMEM_VER} for ${NVSHMEM_SUBDIR} ..."
curl -fSL "${NVSHMEM_URL}" -o "${NVSHMEM_FILE}"
tar -xf "${NVSHMEM_FILE}"
rm -rf nvshmem
mv "${NVSHMEM_FILE%.tar.xz}" nvshmem
rm -f "${NVSHMEM_FILE}"
rm -rf nvshmem/lib/bin nvshmem/lib/share
......
......@@ -410,8 +410,7 @@ class ElasticEPScalingExecutor:
# for the new EP size by resetting quant_method to base
for module in moe_modules:
if hasattr(module.quant_method, "old_quant_method"):
module.quant_method = module.quant_method.old_quant_method
module.runner = module._init_runner()
module._replace_quant_method(module.quant_method.old_quant_method)
prepare_communication_buffer_for_model(self.worker.model_runner.model)
eplb_model_state.communicator = create_eplb_communicator(
......
......@@ -595,10 +595,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
@property
def _shared_experts(self):
return self.base_layer._shared_experts
@property
def quant_method(self):
return self.base_layer.quant_method
......
......@@ -937,6 +937,15 @@ class FusedMoEParallelConfig:
all2all_backend: str # all2all backend for MoE communication
enable_eplb: bool # whether to enable expert load balancing
@property
def use_dp_chunking(self) -> bool:
    """
    Whether DP chunking is in effect for this parallel configuration.

    The result is truthy only when at least one of
    use_deepep_ll_kernels, use_mori_kernels,
    use_fi_nvl_two_sided_kernels or use_nixl_ep_kernels is set and
    envs.VLLM_ENABLE_MOE_DP_CHUNK is enabled.
    """
    # The environment switch is only consulted when one of the kernel
    # flags is set.
    has_chunking_kernels = (
        self.use_deepep_ll_kernels
        or self.use_mori_kernels
        or self.use_fi_nvl_two_sided_kernels
        or self.use_nixl_ep_kernels
    )
    return has_chunking_kernels and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property
def is_sequence_parallel(self) -> bool:
return self.sp_size > 1
......
......@@ -1194,6 +1194,8 @@ def cutlass_moe_w4a8_fp8(
quant_config=quant_config,
group_size=group_size,
),
shared_experts=None,
inplace=False,
)
return fn.apply(
......
......@@ -53,6 +53,7 @@ class TrtLlmFp8ExpertsBase:
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.moe_config = moe_config
self.quant_config = quant_config
@staticmethod
......
......@@ -40,9 +40,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def mk_owns_shared_expert(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return (
self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
)
return self.moe_kernel is not None and self.moe_kernel.owns_shared_experts
@abstractmethod
def create_weights(
......@@ -163,7 +161,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
raise NotImplementedError
def apply_monolithic(
......@@ -171,5 +169,5 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
raise NotImplementedError
......@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEKernel,
FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
logger = init_logger(__name__)
......@@ -44,7 +47,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
moe_layer: torch.nn.Module,
old_quant_method: FusedMoEMethodBase,
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
shared_experts: torch.nn.Module | None,
shared_experts: SharedExperts | None,
inplace: bool = False,
) -> "FusedMoEModularMethod":
return FusedMoEModularMethod(
......@@ -52,8 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
FusedMoEKernel(
prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
moe_parallel_config=moe_layer.moe_parallel_config,
shared_experts=shared_experts,
inplace=inplace,
),
)
......@@ -89,7 +91,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
......
......@@ -42,6 +42,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
DefaultMoERunner,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
......@@ -275,8 +278,6 @@ class FusedMoE(CustomOp):
):
super().__init__()
self._gate = gate
self._shared_experts = shared_experts
self._routed_input_transform = routed_input_transform
if params_dtype is None:
......@@ -486,7 +487,7 @@ class FusedMoE(CustomOp):
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
# TODO: in_dtype == out_dtype?
disable_inplace=disable_inplace() or self._shared_experts is not None,
disable_inplace=disable_inplace() or shared_experts is not None,
)
if self.moe_config.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, (
......@@ -564,34 +565,20 @@ class FusedMoE(CustomOp):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.base_quant_method = self.quant_method
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there nothing to gain
# - we are using marlin kernels
backend = self.moe_parallel_config.all2all_backend
self.use_overlapped = (
not (
(self.enable_eplb and backend != "allgather_reducescatter")
or self.moe_parallel_config.use_fi_nvl_two_sided_kernels
)
and self._shared_experts is not None
)
self.runner = self._init_runner()
# TODO(bnell): this is un-needed and removed in a follow up PR.
self.base_quant_method = self.quant_method
def _init_runner(self):
# Storing the runner in the FusedMoE is an intermediate state, eventually
# the runner will own the FusedMoE layer and provide the execution interface
# for MoE ops.
return DefaultMoERunner(
self.runner = DefaultMoERunner(
layer=self,
moe_config=self.moe_config,
router=self.router,
routed_input_transform=self._routed_input_transform,
gate=self.gate,
shared_experts=self.shared_experts,
gate=gate,
shared_experts=shared_experts,
quant_method=self.quant_method,
reduce_results=self.reduce_results,
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
......@@ -602,10 +589,7 @@ class FusedMoE(CustomOp):
# intrusive way to do this.
def _replace_quant_method(self, mk: FusedMoEMethodBase):
self.quant_method = mk
# We need to force reconstruction of runner because we're swapping out
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
self.runner = self._init_runner()
self.runner._replace_quant_method(mk)
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
......@@ -639,8 +623,8 @@ class FusedMoE(CustomOp):
)
@property
def shared_experts(self) -> torch.nn.Module | None:
return self._shared_experts if self.use_overlapped else None
def shared_experts(self) -> SharedExperts | None:
return self.runner.shared_experts
@property
def layer_id(self):
......@@ -649,10 +633,6 @@ class FusedMoE(CustomOp):
return extract_layer_index(self.layer_name)
@property
def gate(self) -> torch.nn.Module | None:
return self._gate if self.use_overlapped else None
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
......@@ -676,7 +656,7 @@ class FusedMoE(CustomOp):
@property
def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
return self.gate is not None
return self.runner.is_internal_router()
def _maybe_init_expert_routing_tables(
self,
......@@ -1467,7 +1447,12 @@ class FusedMoE(CustomOp):
assert all(
weight.is_contiguous()
for name, weight in weights
if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
if not (
name.startswith("_shared_experts.")
or name.startswith("_gate.")
or name.startswith("_routed_input_transform.")
or name.startswith("_routed_output_transform.")
)
and name not in NON_EXPERT_WEIGHTS
)
......@@ -1477,8 +1462,11 @@ class FusedMoE(CustomOp):
if name not in NON_EXPERT_WEIGHTS
and weight.shape != torch.Size([])
and not name.startswith("_shared_experts.")
# exclude parameters from non-expert submodules (e.g. gate/shared)
# exclude parameters from non-expert submodules,
# e.g. gate/shared/transforms.
and not name.startswith("_gate.")
and not name.startswith("_routed_input_transform.")
and not name.startswith("_routed_output_transform.")
]
def set_eplb_state(
......
......@@ -21,6 +21,10 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
SharedExpertsOrder,
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
disable_inplace,
......@@ -235,6 +239,13 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError
def supports_async(self) -> bool:
    """
    Indicates whether or not this class implements prepare_async and
    finalize_async.

    This base implementation always answers False.
    FusedMoEKernelModularImpl only keeps the shared experts it is given
    when its prepare_finalize object returns True here.
    """
    return False
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
......@@ -281,13 +292,6 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
"""
raise NotImplementedError
def supports_async(self) -> bool:
"""
Indicates whether or not this class implements prepare_async and
finalize_async.
"""
return False
def prepare_async(
self,
a1: torch.Tensor,
......@@ -1003,15 +1007,20 @@ class FusedMoEKernelModularImpl:
self,
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
fused_experts: FusedMoEExpertsModular,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
shared_experts: SharedExperts | None,
inplace: bool = False,
):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
self.moe_parallel_config = moe_parallel_config
# Only accept shared experts if they can be run w/async.
# The MoERunner/SharedExperts class will coordinate with the MK to ensure
# that the SharedExperts are executed only once.
self.shared_experts = (
shared_experts if prepare_finalize.supports_async() else None
)
self.inplace = inplace
moe_parallel_config = fused_experts.moe_config.moe_parallel_config
self.moe_parallel_config = moe_parallel_config
self.is_dp_ep = (
moe_parallel_config is not None
and moe_parallel_config.dp_size > 1
......@@ -1081,6 +1090,17 @@ class FusedMoEKernelModularImpl:
return workspace13, workspace2, fused_out
def _maybe_apply_shared_experts(
self,
shared_experts_input: torch.Tensor | None,
):
if self.shared_experts is not None:
assert shared_experts_input is not None
self.shared_experts.apply(
shared_experts_input,
SharedExpertsOrder.MK_INTERNAL_OVERLAPPED,
)
def _prepare(
self,
hidden_states: torch.Tensor,
......@@ -1253,15 +1273,6 @@ class FusedMoEKernelModularImpl:
shared_experts_input is the original hidden_states (full
dimension) needed by the shared expert MLP.
"""
shared_output: torch.Tensor | None = None
# For latent MoE: shared experts need the original hidden_states
# (full hidden_size), not the latent-projected version used by
# routed experts.
se_hidden_states = (
shared_experts_input if shared_experts_input is not None else hidden_states
)
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
......@@ -1273,8 +1284,6 @@ class FusedMoEKernelModularImpl:
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(se_hidden_states)
else:
finalize_ret = self.prepare_finalize.finalize_async(
output,
......@@ -1284,8 +1293,7 @@ class FusedMoEKernelModularImpl:
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(se_hidden_states)
self._maybe_apply_shared_experts(shared_experts_input)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
......@@ -1308,11 +1316,7 @@ class FusedMoEKernelModularImpl:
receiver()
if self.shared_experts is None:
return output
else:
assert shared_output is not None
return shared_output, output
return output
def apply(
self,
......@@ -1326,7 +1330,7 @@ class FusedMoEKernelModularImpl:
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
......@@ -1469,12 +1473,10 @@ class FusedMoEKernel:
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEExperts,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
shared_experts: SharedExperts | None = None,
inplace: bool = False,
):
super().__init__()
self.shared_experts = shared_experts # NOTE: check if we can remove
# Initialize the implementation (monolithic or modular).
self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
......@@ -1485,14 +1487,12 @@ class FusedMoEKernel:
prepare_finalize,
fused_experts,
shared_experts,
moe_parallel_config,
inplace,
)
elif isinstance(
prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
assert shared_experts is None
assert not inplace
self.impl = FusedMoEKernelMonolithicImpl(
prepare_finalize,
......@@ -1508,6 +1508,13 @@ class FusedMoEKernel:
self._post_init_setup()
@property
def owns_shared_experts(self) -> bool:
    """
    True when the modular implementation holds a shared-experts object.

    Any implementation other than FusedMoEKernelModularImpl yields
    False.
    """
    impl = self.impl
    return (
        isinstance(impl, FusedMoEKernelModularImpl)
        and impl.shared_experts is not None
    )
@property
def is_monolithic(self) -> bool:
return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
......
......@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
......@@ -545,7 +548,7 @@ def make_fp8_moe_kernel(
experts_cls: type[mk.FusedMoEExperts],
fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
shared_experts: SharedExperts | None = None,
) -> mk.FusedMoEKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
......@@ -581,12 +584,7 @@ def make_fp8_moe_kernel(
kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_deepep_ll_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
shared_experts=shared_experts,
inplace=(
not moe_config.disable_inplace
and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
......
......@@ -859,7 +859,6 @@ def make_mxfp4_moe_kernel(
if moe_config.moe_parallel_config.use_deepep_ll_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
inplace=(
not moe_config.disable_inplace and mxfp4_backend not in TRTLLM_BACKENDS
),
......
......@@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
......@@ -386,7 +389,7 @@ def make_nvfp4_moe_kernel(
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEExperts],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
shared_experts: SharedExperts | None = None,
) -> mk.FusedMoEKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
......@@ -422,12 +425,7 @@ def make_nvfp4_moe_kernel(
kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_deepep_ll_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
shared_experts=shared_experts,
inplace=False,
)
......
......@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
convert_moe_weights_to_flashinfer_trtllm_block_layout,
......@@ -321,7 +324,7 @@ def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend,
experts_cls: type[mk.FusedMoEExperts],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
shared_experts: SharedExperts | None = None,
) -> mk.FusedMoEKernel:
# Create Prepare/Finalize
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
......@@ -355,12 +358,7 @@ def make_unquantized_moe_kernel(
kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_deepep_ll_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
shared_experts=shared_experts,
inplace=(not moe_config.disable_inplace and not is_monolithic),
)
......
......@@ -325,7 +325,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
if qc_a1_gscale_or_scale is not None and nvfp4_dispatch
else dict()
),
async_finish=False,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
import torch
......@@ -13,6 +14,13 @@ class FusedMoERouter(ABC):
method that is used for routing hidden states based on router logits.
"""
@abstractmethod
def set_capture_fn(
    self,
    capture_fn: Callable[[torch.Tensor], None] | None,
) -> None:
    """
    Abstract hook that hands the router a capture callback.

    capture_fn is a callable taking one torch.Tensor and returning
    nothing, or None. Subclasses must override this method; the base
    version only raises NotImplementedError.

    NOTE(review): presumably None clears a previously installed
    callback - confirm against the concrete routers.
    """
    raise NotImplementedError
@property
@abstractmethod
def routing_method_type(self) -> RoutingMethodType:
......
......@@ -32,3 +32,7 @@ class MoERunner(ABC):
final_hidden_states: torch.Tensor,
):
raise NotImplementedError
@abstractmethod
def is_internal_router(self) -> bool:
    """
    Abstract query that subclasses must override; the base version only
    raises NotImplementedError.

    FusedMoE.is_internal_router forwards to this method.

    NOTE(review): presumably True when the runner itself invokes the
    gate/router - confirm against the concrete runners.
    """
    raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
import torch
import vllm.envs as envs
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
aux_stream,
current_stream,
)
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
)
logger = init_logger(__name__)
class SharedExpertsOrder(IntEnum):
    """Identifies the point at which a MoE layer's shared experts are run.

    Members are plain integers (0-4). They are written as int literals
    rather than one-element tuples so the declared value is exactly the
    value each member carries.
    """

    # No shared experts.
    NONE = 0

    # Get rid of this one? combine with BEFORE?
    # Note: this might be important for torch.compile reasons. Can
    # get rid of it after _moe_forward is undone.
    EXTERNAL = 1

    # No overlap - defensively called before MK.
    NO_OVERLAP = 2

    # Overlapped with dispatch/combine in DP/EP - called by the MK.
    MK_INTERNAL_OVERLAPPED = 3

    # Overlapped with the gate, router, experts in aux stream.
    MULTI_STREAM_OVERLAPPED = 4
class SharedExperts:
    """Runs a MoE layer's shared-experts module and holds its result.

    ``apply`` executes the wrapped ``layer`` when the order requested by
    the caller matches the order this object selects for the given input,
    and stores the result in a per-ubatch slot. Reading the ``output``
    property hands that result back and empties the slot.
    """

    def __init__(
        self,
        layer: torch.nn.Module,
        moe_config: FusedMoEConfig,
        quant_method: QuantizeMethodBase,
        reduce_results: bool,
        enable_dbo: bool,
    ):
        """Store the configuration and pick the optional auxiliary stream.

        Args:
            layer: Module invoked on the shared-experts input.
            moe_config: MoE configuration; its ``moe_parallel_config`` and
                ``disable_inplace`` fields are consulted by this class.
            quant_method: Must be a ``FusedMoEMethodBase`` instance.
            reduce_results: One of the conditions required before the
                EXTERNAL output is all-reduced.
            enable_dbo: When True the output slot is chosen with
                ``dbo_current_ubatch_id()``; otherwise slot 0 is used.
        """
        # Imported here rather than at module level: quant_method must be
        # a FusedMoEMethodBase but the type can't be used in the signature
        # due to circular imports.
        from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
            FusedMoEMethodBase,
        )

        assert isinstance(quant_method, FusedMoEMethodBase)

        self._layer = layer
        self._moe_config = moe_config
        self._quant_method = quant_method
        self._reduce_results = reduce_results
        self._use_dp_chunking = moe_config.moe_parallel_config.use_dp_chunking

        # The SharedExperts need to handle DBO since they can be called
        # from an MK's finalize method. Outputs are kept in a list indexed
        # by the current DBO ubatch id. If DBO is not enabled, the index is
        # always 0 and the second list element is ignored.
        self.enable_dbo = enable_dbo
        self._output: list[torch.Tensor | None] = [None, None]

        # Allow disabling of the separate shared experts stream for
        # debug purposes.
        # TODO: Remove this after more extensive testing with TP/DP
        # and other execution modes
        stream_disabled = envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
        # TODO(rob): enable shared expert overlap with non-cuda-alike.
        # aux_stream() returns None on non-cuda-alike platforms.
        self._stream = None if stream_disabled else aux_stream()

        if stream_disabled:
            logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
        elif self._stream is not None:
            logger.debug_once(
                "Enabled separate cuda stream for MoE shared_experts", scope="local"
            )

    @property
    def _has_external_experts(self) -> bool:
        """True unless one of the two conditions listed below holds."""
        # The negated conditions are the cases in which shared expert
        # overlap is disabled:
        #   - eplb with a non-default backend, because of correctness issues
        #   - flashinfer with DP, since there is nothing to gain
        # NOTE(review): the property name reads as the opposite polarity
        # of this comment - confirm the intended meaning.
        parallel_config = self._moe_config.moe_parallel_config
        backend = parallel_config.all2all_backend
        eplb_on_other_backend = (
            parallel_config.enable_eplb and backend != "allgather_reducescatter"
        )
        return not (
            eplb_on_other_backend or parallel_config.use_fi_nvl_two_sided_kernels
        )

    def _determine_shared_experts_order(
        self,
        hidden_states: torch.Tensor,
    ) -> SharedExpertsOrder:
        """Select the order in which the shared experts run for this input.

        Checked in sequence: EXTERNAL, then MK_INTERNAL_OVERLAPPED when the
        quant method's kernel owns the shared experts, then
        MULTI_STREAM_OVERLAPPED when the aux stream can be used, and
        NO_OVERLAP as the fallback.
        """
        if self._has_external_experts and not self._use_dp_chunking:
            return SharedExpertsOrder.EXTERNAL

        if self._quant_method.mk_owns_shared_expert:
            return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED

        # The aux stream is only used on CUDA, without DP chunking, when a
        # stream exists and the first dimension of the input is within the
        # configured token threshold.
        can_use_aux_stream = (
            current_platform.is_cuda()
            and not self._use_dp_chunking
            and self._stream is not None
            and hidden_states.shape[0]
            <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
        )
        return (
            SharedExpertsOrder.MULTI_STREAM_OVERLAPPED
            if can_use_aux_stream
            else SharedExpertsOrder.NO_OVERLAP
        )

    def maybe_sync_shared_experts_stream(
        self,
        shared_experts_input: torch.Tensor,
    ):
        """Make the aux stream wait on the current stream when it is used.

        Only acts when the selected order for ``shared_experts_input`` is
        MULTI_STREAM_OVERLAPPED; otherwise this is a no-op.
        """
        selected_order = self._determine_shared_experts_order(shared_experts_input)
        if selected_order != SharedExpertsOrder.MULTI_STREAM_OVERLAPPED:
            return

        assert self._stream is not None
        assert self._moe_config.disable_inplace

        # Record that the clone will be used by shared_experts_stream
        # to avoid gc issue from deallocation of hidden_states_clone
        # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
        # NOTE: We don't need shared_output.record_stream(current_stream())
        # because we sync the streams before using shared_output.
        shared_experts_input.record_stream(self._stream)

        # Mark sync start point for the aux stream since we will
        # run in parallel with router/gate.
        self._stream.wait_stream(current_stream())

    def _run_in_aux_stream(
        self,
        shared_experts_input: torch.Tensor,
    ) -> torch.Tensor:
        """Run the layer on the aux stream, then make the current stream wait."""
        # TODO: assert that maybe_sync_shared_experts_stream has been called.
        # Run shared experts in parallel on a separate stream.
        with torch.cuda.stream(self._stream):
            shared_out = self._layer(shared_experts_input)
        current_stream().wait_stream(self._stream)
        return shared_out

    def _maybe_reduce_shared_out(self, shared_out: torch.Tensor) -> torch.Tensor:
        """All-reduce ``shared_out`` across TP ranks when required.

        Returns the input unchanged when any of the conditions fails.
        """
        # Reduce shared expert outputs if necessary, since the MLP
        # should have been created with reduce_results=False.
        needs_reduce = (
            self._reduce_results
            and self._quant_method.moe_kernel is not None
            and self._quant_method.moe_kernel.output_is_reduced()
            and get_tensor_model_parallel_world_size() > 1
        )
        if needs_reduce:
            return tensor_model_parallel_all_reduce(shared_out)
        return shared_out

    @property
    def _output_idx(self) -> int:
        """Slot index: the current DBO ubatch id, or 0 when DBO is off."""
        if self.enable_dbo:
            return dbo_current_ubatch_id()
        return 0

    @property
    def output(self) -> torch.Tensor:
        """Return the stored result and empty its slot.

        Asserts that a result is present for the current slot.
        """
        result = self._output[self._output_idx]
        assert result is not None
        self._output[self._output_idx] = None
        return result

    def apply(
        self,
        shared_experts_input: torch.Tensor,
        order: SharedExpertsOrder,
    ):
        """Run the shared experts if ``order`` is the order selected here.

        When ``order`` differs from the order selected for
        ``shared_experts_input`` nothing happens. Otherwise the layer is
        run (on the aux stream for MULTI_STREAM_OVERLAPPED) and the result
        is stored for later retrieval through ``output``. The slot for the
        current ubatch must be empty on entry. Always returns None.
        """
        selected_order = self._determine_shared_experts_order(shared_experts_input)
        if order != selected_order:
            return None

        assert self._output[self._output_idx] is None

        if order == SharedExpertsOrder.MULTI_STREAM_OVERLAPPED:
            shared_out = self._run_in_aux_stream(shared_experts_input)
        else:
            shared_out = self._layer(shared_experts_input)
        self._output[self._output_idx] = shared_out

        if order == SharedExpertsOrder.EXTERNAL:
            # TODO: figure out how to combine this with maybe_reduce_output?
            # or get rid of it completely.
            assert self._output[self._output_idx] is not None
            self._output[self._output_idx] = self._maybe_reduce_shared_out(
                self._output[self._output_idx]
            )

        assert self._output[self._output_idx] is not None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment