Commit bc387d5a authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1 (fused_moe)

parent 899a2db4
...@@ -7,27 +7,17 @@ import torch ...@@ -7,27 +7,17 @@ import torch
from vllm.distributed import ( from vllm.distributed import (
get_ep_group, get_ep_group,
) )
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNaiveEP,
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
logger = init_logger(__name__)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
if has_pplx(): if has_pplx():
from .pplx_prepare_finalize import ( from .pplx_prepare_finalize import (
...@@ -80,46 +70,20 @@ def maybe_make_prepare_finalize( ...@@ -80,46 +70,20 @@ def maybe_make_prepare_finalize(
moe: FusedMoEConfig, moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None, quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None: ) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allow us to fall
# back to the old method for methods that have not yet been migrated.
#
# In old method:
# * maybe_init_modular_kernel() calls this function. If we are
# using no Dp/Ep or naive all2all, we return None this function
# returns None and no ModularKernelMethod is created. If non-naive
# all2all is used, this returns a PrepareAndFinalize object and
# a ModularKernelMethod is created.
# In new method:
# * maybe_make_prepare_finalize() is called from the oracle. We
# always return a PrepareAndFinalize object and the quant method
# holds the ModularKernel.
if not moe.moe_parallel_config.use_all2all_kernels: if not moe.moe_parallel_config.use_all2all_kernels:
if not allow_new_interface: return None
return None
# For DP/TP case, fall back to naive P/F.
if moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
return MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
)
else:
return MoEPrepareAndFinalizeNoEP()
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None prepare_finalize: FusedMoEPrepareAndFinalize | None = None
# TODO(rob): update this as part of the MoE refactor.
assert not moe.use_flashinfer_cutlass_kernels, (
"Must be created in modelopt.py or fp8.py"
)
if moe.use_pplx_kernels: if moe.use_pplx_kernels:
assert quant_config is not None assert quant_config is not None
...@@ -239,16 +203,4 @@ def maybe_make_prepare_finalize( ...@@ -239,16 +203,4 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch, use_fp8_dispatch=use_fp8_dispatch,
) )
elif moe.use_fi_all2allv_kernels: return prepare_finalize
assert quant_config is not None \ No newline at end of file
prepare_finalize = FlashInferA2APrepareAndFinalize(
num_dispatchers=all2all_manager.world_size,
)
elif moe.use_naive_all2all_kernels and allow_new_interface:
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
num_dispatchers=all2all_manager.world_size,
)
return prepare_finalize
...@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( ...@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -892,13 +893,7 @@ class FusedMoEParallelConfig: ...@@ -892,13 +893,7 @@ class FusedMoEParallelConfig:
@property @property
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@property
def use_fi_all2allv_kernels(self):
return (
self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv"
)
@property @property
def use_batched_activation_format(self): def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels return self.use_deepep_ll_kernels or self.use_pplx_kernels
...@@ -1030,7 +1025,6 @@ class FusedMoEParallelConfig: ...@@ -1030,7 +1025,6 @@ class FusedMoEParallelConfig:
ep_rank=0, ep_rank=0,
use_ep=False, use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend, all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb, enable_eplb=vllm_parallel_config.enable_eplb,
) )
# DP + EP / TP + EP / DP + TP + EP # DP + EP / TP + EP / DP + TP + EP
...@@ -1050,7 +1044,6 @@ class FusedMoEParallelConfig: ...@@ -1050,7 +1044,6 @@ class FusedMoEParallelConfig:
ep_rank=ep_rank, ep_rank=ep_rank,
use_ep=True, use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend, all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb, enable_eplb=vllm_parallel_config.enable_eplb,
) )
...@@ -1069,7 +1062,6 @@ class FusedMoEParallelConfig: ...@@ -1069,7 +1062,6 @@ class FusedMoEParallelConfig:
use_ep=False, use_ep=False,
all2all_backend="naive", all2all_backend="naive",
enable_eplb=False, enable_eplb=False,
is_sequence_parallel=False,
) )
...@@ -1164,9 +1156,12 @@ class FusedMoEConfig: ...@@ -1164,9 +1156,12 @@ class FusedMoEConfig:
return self.moe_parallel_config.use_mori_kernels return self.moe_parallel_config.use_mori_kernels
@property @property
def use_fi_all2allv_kernels(self): def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_fi_all2allv_kernels """
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
@property """
def use_naive_all2all_kernels(self): return (
return self.moe_parallel_config.use_naive_all2all_kernels envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
)
...@@ -103,14 +103,7 @@ def run_cutlass_moe_fp8( ...@@ -103,14 +103,7 @@ def run_cutlass_moe_fp8(
or a2_scale.size(0) == a1q.shape[0] or a2_scale.size(0) == a1q.shape[0]
), "Intermediate scale shape mismatch" ), "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if expert_map is not None:
# NOTE(rob): the expert_map is used for the STANDARD case and
# the batched format is used by the BATCHED case.
# TODO(rob): update the MK interface to only pass the expert_map
# during the STANDARD case to make this clearer across all kernels.
if use_batched_format:
assert expert_num_tokens is not None
else:
assert expert_num_tokens is None assert expert_num_tokens is None
# We have two modes: batched experts and non-batched experts. # We have two modes: batched experts and non-batched experts.
...@@ -386,10 +379,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -386,10 +379,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# needed for STANDARD activation format kernels in DP/EP mode. # needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use # Note that the BATCHED activation format does not use
# the expert map for identifying experts. # the expert map for identifying experts.
return not ( return not moe_parallel_config.use_all2all_kernels
moe_parallel_config.use_fi_all2allv_kernels
or moe_parallel_config.use_deepep_ht_kernels
)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
...@@ -651,8 +641,10 @@ def run_cutlass_moe_fp4( ...@@ -651,8 +641,10 @@ def run_cutlass_moe_fp4(
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@property @staticmethod
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
return True return True
@staticmethod @staticmethod
...@@ -1177,4 +1169,4 @@ def cutlass_moe_w4a8_fp8( ...@@ -1177,4 +1169,4 @@ def cutlass_moe_w4a8_fp8(
global_num_experts=num_experts, global_num_experts=num_experts,
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
\ No newline at end of file
...@@ -148,8 +148,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -148,8 +148,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# NOTE(rob): discovered an IMA with this combination. Needs investigation. return True
return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
...@@ -308,4 +307,4 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -308,4 +307,4 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
inv_perm=inv_perm, inv_perm=inv_perm,
expert_map=expert_map, expert_map=expert_map,
output=output, output=output,
) )
\ No newline at end of file
...@@ -103,7 +103,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -103,7 +103,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int, num_experts: int,
a1_scale: torch.Tensor | None, a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> Callable: ) -> Callable:
has_scales = token_scales is not None has_scales = token_scales is not None
...@@ -175,7 +174,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -175,7 +174,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights, expert_topk_weights,
a1_scale, a1_scale,
quant_config, quant_config,
defer_input_quant=defer_input_quant,
) )
def _receiver( def _receiver(
...@@ -189,7 +187,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -189,7 +187,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights: torch.Tensor | None, expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None, a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if event.event is not None: if event.event is not None:
event.current_stream_wait() event.current_stream_wait()
...@@ -224,15 +221,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -224,15 +221,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_num_tokens_per_expert_list, device=expert_x.device expert_num_tokens_per_expert_list, device=expert_x.device
) )
# * For non-block quant, dispatch in b16 and quantize now as # Dispatch and Quant
# DeepEP kernels only support dispatching block scales. # DeepEP kernels only support dispatching block-quantized
# * For expert kernels that require unquantized inputs, # activation scales.
# defer quantization to FusedMoEExpertsPermuteUnpermute. # Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized and not defer_input_quant: if not quant_config.is_block_quantized:
# Quantize after dispatch. # Quantize after dispatch.
expert_x_scale = None expert_x_scale = None
if expert_x.numel() != 0: if expert_x.numel() != 0:
# TODO: support per_act_token_quant,
expert_x, expert_x_scale = moe_kernel_quantize_input( expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x, expert_x,
a1_scale, a1_scale,
...@@ -261,7 +257,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -261,7 +257,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.ReceiverType: ) -> mk.ReceiverType:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -271,12 +266,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -271,12 +266,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
# * DeepEP only supports fp8 block scales so quantize if quant_config.is_block_quantized:
# before the dispatch for these models. # Quant and Dispatch
# * For all other quantization, dispatch after.
# * For expert kernels that require unquantized inputs,
# defer quantization to FusedMoEExpertsPermuteUnpermute.
if quant_config.is_block_quantized and not defer_input_quant:
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
quant_config.a1_scale, quant_config.a1_scale,
...@@ -290,11 +281,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -290,11 +281,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else: else:
a1q = a1 a1q = a1
a1q_scale = None a1q_scale = None
a1_post_scale = ( a1_post_scale = quant_config.a1_scale
quant_config.a1_gscale
if quant_config.quant_dtype == "nvfp4"
else quant_config.a1_scale
)
return self._do_dispatch( return self._do_dispatch(
tokens=a1q, tokens=a1q,
...@@ -304,7 +291,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -304,7 +291,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts=num_experts, num_experts=num_experts,
a1_scale=a1_post_scale, a1_scale=a1_post_scale,
quant_config=quant_config, quant_config=quant_config,
defer_input_quant=defer_input_quant,
) )
def prepare( def prepare(
...@@ -316,7 +302,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -316,7 +302,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
receiver = self.prepare_async( receiver = self.prepare_async(
a1, a1,
...@@ -326,7 +311,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -326,7 +311,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config, quant_config,
defer_input_quant,
) )
return receiver() return receiver()
...@@ -433,4 +417,4 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -433,4 +417,4 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input, apply_router_weight_on_input,
weight_and_reduce_impl, weight_and_reduce_impl,
False, False,
) )
\ No newline at end of file
...@@ -242,14 +242,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -242,14 +242,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]: ) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
hidden_size = a1.size(1) hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, ( assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
f"Hidden Size {hidden_size} not in supported list of hidden sizes" f"Hidden Size {hidden_size} not in supported list of hidden sizes"
...@@ -351,13 +344,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -351,13 +344,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
topk_weights, topk_weights,
...@@ -446,4 +433,4 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -446,4 +433,4 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input, apply_router_weight_on_input,
weight_and_reduce_impl, weight_and_reduce_impl,
do_async=False, do_async=False,
) )
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.base_device_communicator import (
All2AllManagerBase,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""Base class for FlashInfer MoE prepare and finalize operations."""
def __init__(
self,
num_dispatchers: int = 1,
):
super().__init__()
self.num_dispatchers_ = num_dispatchers
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return False
def _apply_router_weight_on_input(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""Apply router weight on input if needed."""
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1.mul_(topk_weights.to(a1.dtype))
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input
)
global_num_tokens_cpu = get_local_sizes()
top_k = topk_ids.size(1)
(self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = (
flashinfer_alltoall_dispatch(
self.all2all_manager,
global_num_tokens_cpu,
a1,
quant_config.a1_gscale,
topk_ids,
topk_weights,
top_k,
num_experts,
quant_config,
defer_input_quant=defer_input_quant,
)
)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
top_k = topk_ids.size(1)
token_count = output.shape[0]
fused_expert_output = flashinfer_alltoall_combine(
self.all2all_manager,
fused_expert_output,
top_k=top_k,
token_count=token_count,
alltoall_info=self.alltoall_info,
)
output.copy_(fused_expert_output)
def flashinfer_alltoall_dispatch(
all2all_manager: All2AllManagerBase,
global_num_tokens_cpu: list[int],
x: torch.Tensor,
gs: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
num_experts: int,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert all2all_manager.ensure_alltoall_workspace_initialized(), (
"FlashInfer AllToAll workspace not available"
)
ep_rank = all2all_manager.rank
ep_size = all2all_manager.world_size
max_num_token = (
max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
)
orig_topk_weights_dtype = topk_weights.dtype
alltoall_info, topk_ids, topk_weights, _ = (
MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
topk_ids,
topk_weights,
None,
all2all_manager.prepare_workspace_tensor,
max_num_token,
ep_rank,
ep_size,
num_experts,
num_experts,
top_k,
)
)
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
if not defer_input_quant:
x, x_sf = moe_kernel_quantize_input(
x,
gs,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
x_sf,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
# Swizzle after the A2A if nvfp4.
if quant_config.quant_dtype == "nvfp4":
if x_sf.element_size() == 1:
x_sf = x_sf.view(torch.uint8)
x_sf = nvfp4_block_scale_interleave(x_sf)
else:
# Block-scale path: pass activations through without quantization
x_sf = None
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
return alltoall_info, topk_ids, topk_weights, x, x_sf
def flashinfer_alltoall_combine(
all2all_manager: All2AllManagerBase,
output: torch.Tensor,
top_k: int,
token_count: int,
alltoall_info,
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert all2all_manager.ensure_alltoall_workspace_initialized(), (
"FlashInfer AllToAll workspace not available"
)
return MnnvlMoe.mnnvl_moe_alltoallv_combine(
output,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank=all2all_manager.rank,
ep_size=all2all_manager.world_size,
top_k=top_k,
token_count=token_count,
)
...@@ -78,9 +78,16 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -78,9 +78,16 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# - skip input activation quantization (kernel applies scaling) # - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
@property @staticmethod
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(
return self.quant_config.use_fp8_w8a8 and self.quant_config.is_block_quantized moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# NVFP4 TP kernels and FP8 block-quantized kernels apply
# input quantization inside FusedMoEPermuteExpertsUnpermute.
return (
quant_config.use_nvfp4_w4a4
and not moe_config.moe_parallel_config.use_all2all_kernels
) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
...@@ -138,8 +145,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -138,8 +145,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FLASHINFER_CUTLASS currently uses its down P/F, which does not # FLASHINFER_CUTLASS currently uses its down P/F, which does not
# work with SP. This will be removed in follow up after we get # work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function. # rid of the FlashInfer specific P/F function.
# TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As. return (
return not moe_parallel_config.is_sequence_parallel moe_parallel_config.dp_size == 1
or moe_parallel_config.dp_size == moe_parallel_config.ep_size
)
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
...@@ -186,9 +195,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -186,9 +195,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
workspace1 = (M, K) workspace1 = (M, K)
workspace2 = (0,) workspace2 = (0,)
# For NVFP4, the output is stored in a packed int8 format, # For TP, the quantization is fused with fused_moe call.
# so the actual hidden dim is 2x the size of K here. output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
# The workspace is determined by `aq`, since it comes after any # The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation. # potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape) return (workspace1, workspace2, output_shape)
...@@ -292,4 +300,4 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -292,4 +300,4 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
# No support for LoRA in flashinfer_cutlass_fused_moe. # No support for LoRA in flashinfer_cutlass_fused_moe.
# See TODOs in flashinfer functions runMoe and runMoeMinLantency. # See TODOs in flashinfer functions runMoe and runMoeMinLantency.
raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe") raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe")
\ No newline at end of file
...@@ -533,13 +533,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -533,13 +533,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert a1.dim() == 2 assert a1.dim() == 2
assert topk_ids.dim() == 2 assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0) assert topk_ids.size(0) == a1.size(0)
...@@ -1124,4 +1118,4 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1124,4 +1118,4 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
config=config, config=config,
per_act_token_quant=self.per_act_token_quant, per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape, block_shape=self.block_shape,
) )
\ No newline at end of file
...@@ -597,7 +597,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -597,7 +597,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not moe_parallel_config.use_fi_all2allv_kernels return True
@property @property
def quant_type_id(self) -> int: def quant_type_id(self) -> int:
......
...@@ -1108,6 +1108,7 @@ def dispatch_fused_moe_kernel( ...@@ -1108,6 +1108,7 @@ def dispatch_fused_moe_kernel(
num_experts=B.size(0), num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8, bit=4 if use_int4_w4a16 else 8,
) )
if use_moe_wna16_cuda: if use_moe_wna16_cuda:
invoke_fused_moe_wna16_cuda_kernel( invoke_fused_moe_wna16_cuda_kernel(
A, A,
...@@ -1167,6 +1168,7 @@ def dispatch_fused_moe_kernel( ...@@ -1167,6 +1168,7 @@ def dispatch_fused_moe_kernel(
B_bias, B_bias,
) )
@triton.jit @triton.jit
def compute_identity_kernel( def compute_identity_kernel(
top_k: int, top_k: int,
...@@ -2266,7 +2268,6 @@ def fused_experts_impl( ...@@ -2266,7 +2268,6 @@ def fused_experts_impl(
return out_hidden_states return out_hidden_states
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
......
...@@ -5,7 +5,6 @@ from abc import abstractmethod ...@@ -5,7 +5,6 @@ from abc import abstractmethod
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
...@@ -27,19 +26,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -27,19 +26,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super().__init__() super().__init__()
self.moe: FusedMoEConfig = moe self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None self.moe_quant_config: FusedMoEQuantConfig | None = None
self.moe_mk: mk.FusedMoEModularKernel | None = None
@property
def supports_internal_mk(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return self.moe_mk is not None
@property
def mk_owns_shared_expert(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return self.moe_mk is not None and self.moe_mk.shared_experts is not None
@abstractmethod @abstractmethod
def create_weights( def create_weights(
...@@ -105,8 +91,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -105,8 +91,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property @property
def topk_indices_dtype(self) -> torch.dtype | None: def topk_indices_dtype(self) -> torch.dtype | None:
if self.moe_mk is not None:
return self.moe_mk.prepare_finalize.topk_indices_dtype()
return None return None
@property @property
...@@ -142,4 +126,4 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -142,4 +126,4 @@ class FusedMoEMethodBase(QuantizeMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
\ No newline at end of file
...@@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
): ):
super().__init__(old_quant_method.moe) super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config self.moe_quant_config = old_quant_method.moe_quant_config
self.moe_mk = experts self.fused_experts = experts
self.disable_expert_map = getattr( self.disable_expert_map = getattr(
old_quant_method, old_quant_method,
"disable_expert_map", "disable_expert_map",
not self.moe_mk.supports_expert_map(), not self.fused_experts.supports_expert_map(),
) )
self.old_quant_method = old_quant_method self.old_quant_method = old_quant_method
assert not self.old_quant_method.is_monolithic assert not self.old_quant_method.is_monolithic
...@@ -57,6 +57,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -57,6 +57,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
), ),
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
return self.fused_experts.prepare_finalize.topk_indices_dtype()
@property @property
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb return self.old_quant_method.supports_eplb
...@@ -92,8 +96,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -92,8 +96,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None return self.fused_experts(
return self.moe_mk(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -104,4 +107,4 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -104,4 +107,4 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
) )
\ No newline at end of file
...@@ -757,6 +757,14 @@ class FusedMoE(CustomOp): ...@@ -757,6 +757,14 @@ class FusedMoE(CustomOp):
def use_mori_kernels(self): def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels return self.moe_parallel_config.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return (
self.moe_quant_config is not None
and self.moe_quant_config.quant_dtype == "nvfp4"
and self.moe_config_use_flashinfer_cutlass_kernels
)
@property @property
def use_marlin_kernels(self): def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False) return getattr(self.quant_method, "use_marlin", False)
...@@ -767,7 +775,7 @@ class FusedMoE(CustomOp): ...@@ -767,7 +775,7 @@ class FusedMoE(CustomOp):
self.moe_parallel_config.use_pplx_kernels self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels or self.moe_parallel_config.use_mori_kernels
or self.moe_parallel_config.use_fi_all2allv_kernels or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
) and envs.VLLM_ENABLE_MOE_DP_CHUNK ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property @property
...@@ -1189,8 +1197,6 @@ class FusedMoE(CustomOp): ...@@ -1189,8 +1197,6 @@ class FusedMoE(CustomOp):
# dimension intermediate_size_per_partition is used. # dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type: if is_gguf_weight_type:
...@@ -1567,7 +1573,7 @@ class FusedMoE(CustomOp): ...@@ -1567,7 +1573,7 @@ class FusedMoE(CustomOp):
assert self.quant_method is not None assert self.quant_method is not None
return ( return (
isinstance(self.quant_method, FusedMoEModularMethod) isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr] and self.quant_method.fused_experts.output_is_reduced()
) )
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
...@@ -1706,36 +1712,6 @@ class FusedMoE(CustomOp): ...@@ -1706,36 +1712,6 @@ class FusedMoE(CustomOp):
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True)
zero_expert_result = None
if self.zero_expert_num > 0 and self.zero_expert_type is not None:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=staged_hidden_states,
router_logits=staged_router_logits,
use_grouped_topk=self.use_grouped_topk,
top_k=self.top_k,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
indices_type=self.quant_method.topk_indices_dtype,
enable_eplb=self.enable_eplb,
expert_map=self.expert_map,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count)
# Compute zero_expert_result
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=self.global_num_experts,
zero_expert_type=self.zero_expert_type,
hidden_states=staged_hidden_states,
)
# Matrix multiply. # Matrix multiply.
if self.quant_method.is_monolithic: if self.quant_method.is_monolithic:
...@@ -1831,7 +1807,7 @@ class FusedMoE(CustomOp): ...@@ -1831,7 +1807,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init() self.ensure_dp_chunking_init()
has_separate_shared_experts = ( has_separate_shared_experts = (
not self.quant_method.mk_owns_shared_expert not isinstance(self.quant_method, FusedMoEModularMethod)
and self.shared_experts is not None and self.shared_experts is not None
) )
...@@ -1857,8 +1833,8 @@ class FusedMoE(CustomOp): ...@@ -1857,8 +1833,8 @@ class FusedMoE(CustomOp):
# NOTE(rob): once we finish migrating all the quant methods to use # NOTE(rob): once we finish migrating all the quant methods to use
# MKs, we can remove the naive dispatch/combine path from here. # MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine = ( do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
self.dp_size > 1 and not self.quant_method.supports_internal_mk self.quant_method, FusedMoEModularMethod
) )
ctx = get_forward_context() ctx = get_forward_context()
...@@ -1886,7 +1862,7 @@ class FusedMoE(CustomOp): ...@@ -1886,7 +1862,7 @@ class FusedMoE(CustomOp):
else: else:
hidden_states_to_dispatch = hidden_states hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch_router_logits( dispatch_res = get_ep_group().dispatch(
hidden_states_to_dispatch, hidden_states_to_dispatch,
router_logits, router_logits,
self.is_sequence_parallel, self.is_sequence_parallel,
...@@ -1948,7 +1924,6 @@ class FusedMoE(CustomOp): ...@@ -1948,7 +1924,6 @@ class FusedMoE(CustomOp):
if self.capture is not None: if self.capture is not None:
self.capture(topk_ids) self.capture(topk_ids)
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=x, # The type signture of this is wrong due to the hack. x=x, # The type signture of this is wrong due to the hack.
...@@ -1988,7 +1963,6 @@ class FusedMoE(CustomOp): ...@@ -1988,7 +1963,6 @@ class FusedMoE(CustomOp):
dim=0, dim=0,
) )
return states return states
if self.shared_experts is not None: if self.shared_experts is not None:
return ( return (
final_hidden_states[0], final_hidden_states[0],
...@@ -2043,7 +2017,6 @@ class FusedMoE(CustomOp): ...@@ -2043,7 +2017,6 @@ class FusedMoE(CustomOp):
] ]
] ]
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = ( s = (
f"global_num_experts={self.global_num_experts}, " f"global_num_experts={self.global_num_experts}, "
...@@ -2090,8 +2063,6 @@ def moe_forward_fake( ...@@ -2090,8 +2063,6 @@ def moe_forward_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
......
...@@ -180,7 +180,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -180,7 +180,6 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> PrepareResultType: ) -> PrepareResultType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel. Perform any quantization (and/or) dispatching needed for this kernel.
...@@ -193,9 +192,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -193,9 +192,6 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts. - quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a tuple of: Returns a tuple of:
- quantized + dispatched a. - quantized + dispatched a.
...@@ -224,7 +220,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -224,7 +220,6 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> tuple[Callable, ReceiverType] | ReceiverType: ) -> tuple[Callable, ReceiverType] | ReceiverType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel Perform any quantization (and/or) dispatching needed for this kernel
...@@ -240,9 +235,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -240,9 +235,6 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard. space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a callback or a hook callback pair that when invoked waits for Returns a callback or a hook callback pair that when invoked waits for
results from other workers and has the same return signature as results from other workers and has the same return signature as
...@@ -415,8 +407,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -415,8 +407,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
@property @staticmethod
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
""" """
Whether or not the PrepareFinalize should defer input quantization Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will in the prepare step. If True, then the Experts kernel will
...@@ -1075,7 +1069,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1075,7 +1069,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
) )
else: else:
# Overlap shared expert compute with all2all dispatch. # Overlap shared expert compute with all2all dispatch.
...@@ -1088,7 +1081,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1088,7 +1081,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
) )
# TODO(lucas): refactor this in the alternative schedules followup # TODO(lucas): refactor this in the alternative schedules followup
...@@ -1139,7 +1131,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1139,7 +1131,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
) -> torch.Tensor: ) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size( _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids a1q, w1, w2, topk_ids
...@@ -1215,7 +1206,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1215,7 +1206,6 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2, workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
) )
return fused_out return fused_out
...@@ -1299,7 +1289,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1299,7 +1289,6 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -1361,7 +1350,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1361,7 +1350,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
) )
return self._finalize( return self._finalize(
...@@ -1371,4 +1359,4 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1371,4 +1359,4 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights, topk_weights,
topk_ids, topk_ids,
apply_router_weight_on_input, apply_router_weight_on_input,
) )
\ No newline at end of file
...@@ -3,6 +3,70 @@ ...@@ -3,6 +3,70 @@
import torch import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: torch.Tensor | None,
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: torch.Tensor | None,
block_m: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
)
inv_perm: torch.Tensor | None = None
num_tokens = top_k_num * tokens_in_chunk
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: torch.Tensor | None,
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.size()
K = curr_hidden.size(-1)
if inv_perm is not None:
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
if not apply_router_weight_on_input:
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def moe_permute( def moe_permute(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -162,4 +226,4 @@ def moe_unpermute( ...@@ -162,4 +226,4 @@ def moe_unpermute(
def moe_permute_unpermute_supported(): def moe_permute_unpermute_supported():
return torch.ops._moe_C.moe_permute_unpermute_supported() return torch.ops._moe_C.moe_permute_unpermute_supported()
\ No newline at end of file
...@@ -58,7 +58,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -58,7 +58,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
""" """
Returns a tuple of: Returns a tuple of:
...@@ -70,11 +69,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -70,11 +69,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- Optional dispatched expert topk IDs - Optional dispatched expert topk IDs
- Optional dispatched expert topk weight - Optional dispatched expert topk weight
""" """
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now." "mori does not support apply_router_weight_on_input=True now."
) )
...@@ -124,4 +118,4 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -124,4 +118,4 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
None, None,
topk_ids, topk_ids,
)[0] )[0]
output.copy_(result[:num_token]) output.copy_(result[:num_token])
\ No newline at end of file
...@@ -8,9 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -8,9 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -20,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -20,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_fp8, is_supported_config_trtllm_fp8,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
...@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( ...@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
) )
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -331,16 +330,9 @@ def select_fp8_moe_backend( ...@@ -331,16 +330,9 @@ def select_fp8_moe_backend(
else: else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local") logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
# TODO(rob): per discussion with TPU team, we need a way to register raise NotImplementedError(
# MoE backends by OOT plugins, rather than having an explicit list "No FP8 MoE backend supports the deployment configuration."
# of AVAILBLE_BACKENDS. Enabling returning `Fp8MoeBackend.NONE` is )
# a temporary measure until these register APIs are complete.
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
"No FP8 MoE backend supports the deployment configuration."
)
return Fp8MoeBackend.NONE, None
def convert_to_fp8_moe_kernel_format( def convert_to_fp8_moe_kernel_format(
...@@ -465,55 +457,71 @@ def make_fp8_moe_quant_config( ...@@ -465,55 +457,71 @@ def make_fp8_moe_quant_config(
) )
def make_fp8_moe_kernel( def make_fp8_moe_kernel_for_mkm(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
fp8_backend: Fp8MoeBackend, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPermuteExpertsUnpermute:
shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank() max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None assert max_num_tokens_per_rank is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=quant_config,
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"Fp8 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
) )
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in # if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=( shared_experts=None,
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
) )
# TODO(rob): update inplace logic to be part of the kernel. # TODO(rob): update inplace logic to be part of the kernel.
inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
return kernel, inplace return kernel, inplace
\ No newline at end of file
...@@ -7,9 +7,6 @@ import torch ...@@ -7,9 +7,6 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -17,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -17,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm, is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
...@@ -391,53 +391,69 @@ def make_nvfp4_moe_quant_config( ...@@ -391,53 +391,69 @@ def make_nvfp4_moe_quant_config(
) )
def make_nvfp4_moe_kernel( def make_nvfp4_moe_kernel_for_mkm(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None = None, ) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank() max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None assert max_num_tokens_per_rank is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=quant_config,
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=quant_config,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in # if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=( shared_experts=None,
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
) )
# TODO(rob): update inplace logic to be part of the kernel. # TODO(rob): update inplace logic to be part of the kernel.
return kernel return kernel
\ No newline at end of file
...@@ -106,14 +106,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -106,14 +106,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]: ) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
...@@ -281,7 +274,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -281,7 +274,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
...@@ -291,7 +283,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -291,7 +283,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config, quant_config,
defer_input_quant=defer_input_quant,
) )
hook() hook()
return receiver() return receiver()
...@@ -368,4 +359,4 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -368,4 +359,4 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input, apply_router_weight_on_input,
weight_and_reduce_impl, weight_and_reduce_impl,
) )
receiver() receiver()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment