Unverified Commit 5a93b916 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarAmir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avataramirkl94 <203507526+amirkl94@users.noreply.github.com>
parent 6d86fde0
......@@ -1131,7 +1131,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
......
......@@ -1017,7 +1017,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
......
......@@ -85,7 +85,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
......
......@@ -197,7 +197,7 @@ def bench_run(
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
make_dummy_moe_config(),
quant_config=quant_config,
......@@ -242,7 +242,7 @@ def bench_run(
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
make_dummy_moe_config(),
quant_config=quant_config,
......
......@@ -36,8 +36,7 @@ th {
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] |
| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
......
......@@ -22,6 +22,9 @@ from vllm.distributed import (
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -40,7 +43,6 @@ from .mk_objects import (
TestMoEQuantConfig,
expert_info,
make_fused_experts,
make_prepare_finalize,
prepare_finalize_info,
)
from .parallel_utils import ProcessGroupInfo
......@@ -603,10 +605,12 @@ def make_modular_kernel(
routing_method=RoutingMethodType.DeepSeekV3,
)
# make modular kernel
prepare_finalize = make_prepare_finalize(
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
prepare_finalize = maybe_make_prepare_finalize(
moe=moe,
quant_config=quant_config,
allow_new_interface=True,
)
assert prepare_finalize is not None
fused_experts = make_fused_experts(
config.fused_experts_type,
......
......@@ -7,9 +7,6 @@ import torch
# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
......@@ -255,13 +252,12 @@ if has_pplx():
)
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
create_flashinfer_prepare_finalize,
)
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
......@@ -429,24 +425,6 @@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
]
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: str | None,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return create_flashinfer_prepare_finalize(
use_dp=moe.moe_parallel_config.dp_size > 1
)
else:
return MoEPrepareAndFinalizeNoEP()
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts
e = s + num_local_experts
......
......@@ -294,12 +294,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
moe_config=moe_config,
quant_config=quant_config,
......
......@@ -106,12 +106,7 @@ def test_flashinfer_fp4_moe_no_graph(
)
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)
......
......@@ -90,7 +90,7 @@ def test_cutlass_fp4_moe_no_graph(
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
......
......@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
hidden_states = gathered_tensors[0]
topk_weights = gathered_tensors[1]
topk_ids = gathered_tensors[2]
if extra_tensors is None:
return hidden_states, topk_weights, topk_ids
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs):
raise NotImplementedError
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
from weakref import WeakValueDictionary
import torch
......@@ -64,13 +63,32 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
......@@ -280,7 +298,7 @@ class DeviceCommunicatorBase:
for module in moe_modules:
module.maybe_init_modular_kernel()
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -294,8 +312,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
......@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src)
def dispatch( # type: ignore[override]
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
return hidden_states
class _CPUSHMDistributed:
......
......@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch( # type: ignore[override]
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
return hidden_states
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
......@@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group)
return gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
raise NotImplementedError
def barrier(self) -> None:
raise NotImplementedError
def Split(self, color: int, key: int) -> "CustomCommunicator":
return self
......@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
return hidden_states
......@@ -1000,7 +1000,7 @@ class GroupCoordinator:
if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -1011,7 +1011,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch( # type: ignore[call-arg]
return self.device_communicator.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
......@@ -1020,6 +1020,28 @@ class GroupCoordinator:
else:
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
......@@ -7,17 +7,27 @@ import torch
from vllm.distributed import (
get_ep_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNaiveEP,
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
logger = init_logger(__name__)
if current_platform.is_cuda_alike():
if has_pplx():
from .pplx_prepare_finalize import (
......@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allow us to fall
# back to the old method for methods that have not yet been migrated.
#
# In old method:
# * maybe_init_modular_kernel() calls this function. If we are
# using no Dp/Ep or naive all2all, we return None this function
# returns None and no ModularKernelMethod is created. If non-naive
# all2all is used, this returns a PrepareAndFinalize object and
# a ModularKernelMethod is created.
# In new method:
# * maybe_make_prepare_finalize() is called from the oracle. We
# always return a PrepareAndFinalize object and the quant method
# holds the ModularKernel.
if not moe.moe_parallel_config.use_all2all_kernels:
return None
if not allow_new_interface:
return None
# For DP/TP case, fall back to naive P/F.
if moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
return MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
)
else:
return MoEPrepareAndFinalizeNoEP()
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
# TODO(rob): update this as part of the MoE refactor.
assert not moe.use_flashinfer_cutlass_kernels, (
"Must be created in modelopt.py or fp8.py"
)
if moe.use_pplx_kernels:
assert quant_config is not None
......@@ -203,4 +239,16 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch,
)
elif moe.use_fi_all2allv_kernels:
assert quant_config is not None
prepare_finalize = FlashInferA2APrepareAndFinalize(
num_dispatchers=all2all_manager.world_size,
)
elif moe.use_naive_all2all_kernels and allow_new_interface:
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
num_dispatchers=all2all_manager.world_size,
)
return prepare_finalize
......@@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv
......@@ -862,6 +861,7 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
is_sequence_parallel: bool # whether sequence parallelism is used
enable_eplb: bool # whether to enable expert load balancing
@property
......@@ -883,6 +883,12 @@ class FusedMoEParallelConfig:
def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@property
def use_fi_all2allv_kernels(self):
return (
self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv"
)
@property
def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels
......@@ -1014,6 +1020,7 @@ class FusedMoEParallelConfig:
ep_rank=0,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
# DP + EP / TP + EP / DP + TP + EP
......@@ -1033,6 +1040,7 @@ class FusedMoEParallelConfig:
ep_rank=ep_rank,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
......@@ -1051,6 +1059,7 @@ class FusedMoEParallelConfig:
use_ep=False,
all2all_backend="naive",
enable_eplb=False,
is_sequence_parallel=False,
)
......@@ -1145,12 +1154,9 @@ class FusedMoEConfig:
return self.moe_parallel_config.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
"""
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
)
def use_fi_all2allv_kernels(self):
return self.moe_parallel_config.use_fi_all2allv_kernels
@property
def use_naive_all2all_kernels(self):
return self.moe_parallel_config.use_naive_all2all_kernels
......@@ -103,7 +103,14 @@ def run_cutlass_moe_fp8(
or a2_scale.size(0) == a1q.shape[0]
), "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if expert_map is not None:
# NOTE(rob): the expert_map is used for the STANDARD case and
# the batched format is used by the BATCHED case.
# TODO(rob): update the MK interface to only pass the expert_map
# during the STANDARD case to make this clearer across all kernels.
if use_batched_format:
assert expert_num_tokens is not None
else:
assert expert_num_tokens is None
# We have two modes: batched experts and non-batched experts.
......@@ -379,7 +386,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return not moe_parallel_config.use_all2all_kernels
return not (
moe_parallel_config.use_fi_all2allv_kernels
or moe_parallel_config.use_deepep_ht_kernels
)
def supports_chunking(self) -> bool:
return True
......@@ -641,10 +651,8 @@ def run_cutlass_moe_fp4(
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment