Unverified Commit 5a93b916 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarAmir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avataramirkl94 <203507526+amirkl94@users.noreply.github.com>
parent 6d86fde0
...@@ -1131,7 +1131,7 @@ steps: ...@@ -1131,7 +1131,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/ - csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py - vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py - vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py - vllm/v1/attention/backends/mla/cutlass_mla.py
......
...@@ -1017,7 +1017,7 @@ steps: ...@@ -1017,7 +1017,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/ - csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py - vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py - vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py - vllm/v1/attention/backends/mla/cutlass_mla.py
......
...@@ -85,7 +85,7 @@ steps: ...@@ -85,7 +85,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/ - csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py - vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py - vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py - vllm/v1/attention/backends/mla/cutlass_mla.py
......
...@@ -197,7 +197,7 @@ def bench_run( ...@@ -197,7 +197,7 @@ def bench_run(
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True), MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4( CutlassExpertsFp4(
make_dummy_moe_config(), make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
...@@ -242,7 +242,7 @@ def bench_run( ...@@ -242,7 +242,7 @@ def bench_run(
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True), MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4( CutlassExpertsFp4(
make_dummy_moe_config(), make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
......
...@@ -36,8 +36,7 @@ th { ...@@ -36,8 +36,7 @@ th {
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] | | pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] | | deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] | | flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] | | MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] | | BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
......
...@@ -22,6 +22,9 @@ from vllm.distributed import ( ...@@ -22,6 +22,9 @@ from vllm.distributed import (
) )
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -40,7 +43,6 @@ from .mk_objects import ( ...@@ -40,7 +43,6 @@ from .mk_objects import (
TestMoEQuantConfig, TestMoEQuantConfig,
expert_info, expert_info,
make_fused_experts, make_fused_experts,
make_prepare_finalize,
prepare_finalize_info, prepare_finalize_info,
) )
from .parallel_utils import ProcessGroupInfo from .parallel_utils import ProcessGroupInfo
...@@ -603,10 +605,12 @@ def make_modular_kernel( ...@@ -603,10 +605,12 @@ def make_modular_kernel(
routing_method=RoutingMethodType.DeepSeekV3, routing_method=RoutingMethodType.DeepSeekV3,
) )
# make modular kernel prepare_finalize = maybe_make_prepare_finalize(
prepare_finalize = make_prepare_finalize( moe=moe,
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config quant_config=quant_config,
allow_new_interface=True,
) )
assert prepare_finalize is not None
fused_experts = make_fused_experts( fused_experts = make_fused_experts(
config.fused_experts_type, config.fused_experts_type,
......
...@@ -7,9 +7,6 @@ import torch ...@@ -7,9 +7,6 @@ import torch
# Fused experts and PrepareFinalize imports # Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts, BatchedDeepGemmExperts,
) )
...@@ -255,13 +252,12 @@ if has_pplx(): ...@@ -255,13 +252,12 @@ if has_pplx():
) )
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, FlashInferExperts,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
create_flashinfer_prepare_finalize,
)
register_prepare_and_finalize( register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize, FlashInferCutlassMoEPrepareAndFinalize,
...@@ -429,24 +425,6 @@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): ...@@ -429,24 +425,6 @@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
] ]
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: str | None,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return create_flashinfer_prepare_finalize(
use_dp=moe.moe_parallel_config.dp_size > 1
)
else:
return MoEPrepareAndFinalizeNoEP()
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts s = rank * num_local_experts
e = s + num_local_experts e = s + num_local_experts
......
...@@ -294,12 +294,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -294,12 +294,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP( MoEPrepareAndFinalizeNoEP(),
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
FlashInferExperts( FlashInferExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
......
...@@ -106,12 +106,7 @@ def test_flashinfer_fp4_moe_no_graph( ...@@ -106,12 +106,7 @@ def test_flashinfer_fp4_moe_no_graph(
) )
flashinfer_experts = FusedMoEModularKernel( flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP( MoEPrepareAndFinalizeNoEP(),
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config), FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
) )
......
...@@ -90,7 +90,7 @@ def test_cutlass_fp4_moe_no_graph( ...@@ -90,7 +90,7 @@ def test_cutlass_fp4_moe_no_graph(
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True), MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4( CutlassExpertsFp4(
moe_config=make_dummy_moe_config(), moe_config=make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
......
...@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer return buffer
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group): def __init__(self, cpu_group):
super().__init__(cpu_group) super().__init__(cpu_group)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1] return gathered_tensors[0], gathered_tensors[1]
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
hidden_states = gathered_tensors[0]
topk_weights = gathered_tensors[1]
topk_ids = gathered_tensors[2]
if extra_tensors is None:
return hidden_states, topk_weights, topk_ids
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
) )
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
import torch import torch
...@@ -64,13 +63,32 @@ class All2AllManagerBase: ...@@ -64,13 +63,32 @@ class All2AllManagerBase:
# and reuse it for the same config. # and reuse it for the same config.
raise NotImplementedError raise NotImplementedError
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> Any: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either: # Subclasses should either:
# - implement handling for extra_tensors, or # - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported. # - raise a clear error if extra_tensors is not supported.
...@@ -280,7 +298,7 @@ class DeviceCommunicatorBase: ...@@ -280,7 +298,7 @@ class DeviceCommunicatorBase:
for module in moe_modules: for module in moe_modules:
module.maybe_init_modular_kernel() module.maybe_init_modular_kernel()
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -294,8 +312,29 @@ class DeviceCommunicatorBase: ...@@ -294,8 +312,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
""" """
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]: ) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src) return self.dist_module.recv_tensor_dict(src)
def dispatch( # type: ignore[override] def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
is_sequence_parallel,
) )
return hidden_states
class _CPUSHMDistributed: class _CPUSHMDistributed:
......
...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch( # type: ignore[override] def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor] tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
is_sequence_parallel,
) )
return hidden_states
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend from flashinfer.comm.mnnvl import CommBackend as CommBackend
...@@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend): ...@@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group) dist.all_gather_object(gathered, data, group=self._group)
return gathered return gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
raise NotImplementedError
def barrier(self) -> None:
raise NotImplementedError
def Split(self, color: int, key: int) -> "CustomCommunicator": def Split(self, color: int, key: int) -> "CustomCommunicator":
return self return self
...@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase): ...@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group) dist.broadcast(input_, src=src, group=self.device_group)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
is_sequence_parallel,
) )
return hidden_states
...@@ -1000,7 +1000,7 @@ class GroupCoordinator: ...@@ -1000,7 +1000,7 @@ class GroupCoordinator:
if self.device_communicator is not None: if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model) self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -1011,7 +1011,7 @@ class GroupCoordinator: ...@@ -1011,7 +1011,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch( # type: ignore[call-arg] return self.device_communicator.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
...@@ -1020,6 +1020,28 @@ class GroupCoordinator: ...@@ -1020,6 +1020,28 @@ class GroupCoordinator:
else: else:
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states, is_sequence_parallel: bool = False self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -7,17 +7,27 @@ import torch ...@@ -7,17 +7,27 @@ import torch
from vllm.distributed import ( from vllm.distributed import (
get_ep_group, get_ep_group,
) )
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNaiveEP,
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
logger = init_logger(__name__)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
if has_pplx(): if has_pplx():
from .pplx_prepare_finalize import ( from .pplx_prepare_finalize import (
...@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize( ...@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize(
moe: FusedMoEConfig, moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None, quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None: ) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allow us to fall
# back to the old method for methods that have not yet been migrated.
#
# In old method:
# * maybe_init_modular_kernel() calls this function. If we are
# using no Dp/Ep or naive all2all, we return None this function
# returns None and no ModularKernelMethod is created. If non-naive
# all2all is used, this returns a PrepareAndFinalize object and
# a ModularKernelMethod is created.
# In new method:
# * maybe_make_prepare_finalize() is called from the oracle. We
# always return a PrepareAndFinalize object and the quant method
# holds the ModularKernel.
if not moe.moe_parallel_config.use_all2all_kernels: if not moe.moe_parallel_config.use_all2all_kernels:
if not allow_new_interface:
return None return None
# For DP/TP case, fall back to naive P/F.
if moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
return MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
)
else:
return MoEPrepareAndFinalizeNoEP()
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None prepare_finalize: FusedMoEPrepareAndFinalize | None = None
# TODO(rob): update this as part of the MoE refactor.
assert not moe.use_flashinfer_cutlass_kernels, (
"Must be created in modelopt.py or fp8.py"
)
if moe.use_pplx_kernels: if moe.use_pplx_kernels:
assert quant_config is not None assert quant_config is not None
...@@ -203,4 +239,16 @@ def maybe_make_prepare_finalize( ...@@ -203,4 +239,16 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch, use_fp8_dispatch=use_fp8_dispatch,
) )
elif moe.use_fi_all2allv_kernels:
assert quant_config is not None
prepare_finalize = FlashInferA2APrepareAndFinalize(
num_dispatchers=all2all_manager.world_size,
)
elif moe.use_naive_all2all_kernels and allow_new_interface:
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
num_dispatchers=all2all_manager.world_size,
)
return prepare_finalize return prepare_finalize
...@@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( ...@@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -862,6 +861,7 @@ class FusedMoEParallelConfig: ...@@ -862,6 +861,7 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication all2all_backend: str # all2all backend for MoE communication
is_sequence_parallel: bool # whether sequence parallelism is used
enable_eplb: bool # whether to enable expert load balancing enable_eplb: bool # whether to enable expert load balancing
@property @property
...@@ -883,6 +883,12 @@ class FusedMoEParallelConfig: ...@@ -883,6 +883,12 @@ class FusedMoEParallelConfig:
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@property
def use_fi_all2allv_kernels(self):
return (
self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv"
)
@property @property
def use_batched_activation_format(self): def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels return self.use_deepep_ll_kernels or self.use_pplx_kernels
...@@ -1014,6 +1020,7 @@ class FusedMoEParallelConfig: ...@@ -1014,6 +1020,7 @@ class FusedMoEParallelConfig:
ep_rank=0, ep_rank=0,
use_ep=False, use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend, all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb, enable_eplb=vllm_parallel_config.enable_eplb,
) )
# DP + EP / TP + EP / DP + TP + EP # DP + EP / TP + EP / DP + TP + EP
...@@ -1033,6 +1040,7 @@ class FusedMoEParallelConfig: ...@@ -1033,6 +1040,7 @@ class FusedMoEParallelConfig:
ep_rank=ep_rank, ep_rank=ep_rank,
use_ep=True, use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend, all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb, enable_eplb=vllm_parallel_config.enable_eplb,
) )
...@@ -1051,6 +1059,7 @@ class FusedMoEParallelConfig: ...@@ -1051,6 +1059,7 @@ class FusedMoEParallelConfig:
use_ep=False, use_ep=False,
all2all_backend="naive", all2all_backend="naive",
enable_eplb=False, enable_eplb=False,
is_sequence_parallel=False,
) )
...@@ -1145,12 +1154,9 @@ class FusedMoEConfig: ...@@ -1145,12 +1154,9 @@ class FusedMoEConfig:
return self.moe_parallel_config.use_mori_kernels return self.moe_parallel_config.use_mori_kernels
@property @property
def use_flashinfer_cutlass_kernels(self): def use_fi_all2allv_kernels(self):
""" return self.moe_parallel_config.use_fi_all2allv_kernels
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
""" @property
return ( def use_naive_all2all_kernels(self):
envs.VLLM_USE_FLASHINFER_MOE_FP4 return self.moe_parallel_config.use_naive_all2all_kernels
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
)
...@@ -103,7 +103,14 @@ def run_cutlass_moe_fp8( ...@@ -103,7 +103,14 @@ def run_cutlass_moe_fp8(
or a2_scale.size(0) == a1q.shape[0] or a2_scale.size(0) == a1q.shape[0]
), "Intermediate scale shape mismatch" ), "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if expert_map is not None:
# NOTE(rob): the expert_map is used for the STANDARD case and
# the batched format is used by the BATCHED case.
# TODO(rob): update the MK interface to only pass the expert_map
# during the STANDARD case to make this clearer across all kernels.
if use_batched_format:
assert expert_num_tokens is not None
else:
assert expert_num_tokens is None assert expert_num_tokens is None
# We have two modes: batched experts and non-batched experts. # We have two modes: batched experts and non-batched experts.
...@@ -379,7 +386,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -379,7 +386,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# needed for STANDARD activation format kernels in DP/EP mode. # needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use # Note that the BATCHED activation format does not use
# the expert map for identifying experts. # the expert map for identifying experts.
return not moe_parallel_config.use_all2all_kernels return not (
moe_parallel_config.use_fi_all2allv_kernels
or moe_parallel_config.use_deepep_ht_kernels
)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
...@@ -641,10 +651,8 @@ def run_cutlass_moe_fp4( ...@@ -641,10 +651,8 @@ def run_cutlass_moe_fp4(
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @property
def expects_unquantized_inputs( def expects_unquantized_inputs(self) -> bool:
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
return True return True
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment