Commit bc387d5a authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1 (fused_moe)

parent 899a2db4
......@@ -7,27 +7,17 @@ import torch
from vllm.distributed import (
get_ep_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNaiveEP,
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
logger = init_logger(__name__)
if current_platform.is_cuda_alike():
if has_pplx():
from .pplx_prepare_finalize import (
......@@ -80,46 +70,20 @@ def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allow us to fall
# back to the old method for methods that have not yet been migrated.
#
# In old method:
# * maybe_init_modular_kernel() calls this function. If we are
# using no Dp/Ep or naive all2all, we return None this function
# returns None and no ModularKernelMethod is created. If non-naive
# all2all is used, this returns a PrepareAndFinalize object and
# a ModularKernelMethod is created.
# In new method:
# * maybe_make_prepare_finalize() is called from the oracle. We
# always return a PrepareAndFinalize object and the quant method
# holds the ModularKernel.
if not moe.moe_parallel_config.use_all2all_kernels:
if not allow_new_interface:
return None
# For DP/TP case, fall back to naive P/F.
if moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
return MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
)
else:
return MoEPrepareAndFinalizeNoEP()
return None
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
# TODO(rob): update this as part of the MoE refactor.
assert not moe.use_flashinfer_cutlass_kernels, (
"Must be created in modelopt.py or fp8.py"
)
if moe.use_pplx_kernels:
assert quant_config is not None
......@@ -239,16 +203,4 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch,
)
elif moe.use_fi_all2allv_kernels:
assert quant_config is not None
prepare_finalize = FlashInferA2APrepareAndFinalize(
num_dispatchers=all2all_manager.world_size,
)
elif moe.use_naive_all2all_kernels and allow_new_interface:
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
num_dispatchers=all2all_manager.world_size,
)
return prepare_finalize
return prepare_finalize
\ No newline at end of file
......@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv
......@@ -892,13 +893,7 @@ class FusedMoEParallelConfig:
@property
def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@property
def use_fi_all2allv_kernels(self):
return (
self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv"
)
@property
def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels
......@@ -1030,7 +1025,6 @@ class FusedMoEParallelConfig:
ep_rank=0,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
# DP + EP / TP + EP / DP + TP + EP
......@@ -1050,7 +1044,6 @@ class FusedMoEParallelConfig:
ep_rank=ep_rank,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
......@@ -1069,7 +1062,6 @@ class FusedMoEParallelConfig:
use_ep=False,
all2all_backend="naive",
enable_eplb=False,
is_sequence_parallel=False,
)
......@@ -1164,9 +1156,12 @@ class FusedMoEConfig:
return self.moe_parallel_config.use_mori_kernels
@property
def use_fi_all2allv_kernels(self):
return self.moe_parallel_config.use_fi_all2allv_kernels
@property
def use_naive_all2all_kernels(self):
return self.moe_parallel_config.use_naive_all2all_kernels
def use_flashinfer_cutlass_kernels(self):
"""
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
)
......@@ -103,14 +103,7 @@ def run_cutlass_moe_fp8(
or a2_scale.size(0) == a1q.shape[0]
), "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
# NOTE(rob): the expert_map is used for the STANDARD case and
# the batched format is used by the BATCHED case.
# TODO(rob): update the MK interface to only pass the expert_map
# during the STANDARD case to make this clearer across all kernels.
if use_batched_format:
assert expert_num_tokens is not None
else:
if expert_map is not None:
assert expert_num_tokens is None
# We have two modes: batched experts and non-batched experts.
......@@ -386,10 +379,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return not (
moe_parallel_config.use_fi_all2allv_kernels
or moe_parallel_config.use_deepep_ht_kernels
)
return not moe_parallel_config.use_all2all_kernels
def supports_chunking(self) -> bool:
return True
......@@ -651,8 +641,10 @@ def run_cutlass_moe_fp4(
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@property
def expects_unquantized_inputs(self) -> bool:
@staticmethod
def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
return True
@staticmethod
......@@ -1177,4 +1169,4 @@ def cutlass_moe_w4a8_fp8(
global_num_experts=num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
)
\ No newline at end of file
......@@ -148,8 +148,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# NOTE(rob): discovered an IMA with this combination. Needs investigation.
return not moe_parallel_config.use_fi_all2allv_kernels
return True
def supports_chunking(self) -> bool:
return True
......@@ -308,4 +307,4 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
inv_perm=inv_perm,
expert_map=expert_map,
output=output,
)
)
\ No newline at end of file
......@@ -103,7 +103,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> Callable:
has_scales = token_scales is not None
......@@ -175,7 +174,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights,
a1_scale,
quant_config,
defer_input_quant=defer_input_quant,
)
def _receiver(
......@@ -189,7 +187,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> mk.PrepareResultType:
if event.event is not None:
event.current_stream_wait()
......@@ -224,15 +221,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_num_tokens_per_expert_list, device=expert_x.device
)
# * For non-block quant, dispatch in b16 and quantize now as
# DeepEP kernels only support dispatching block scales.
# * For expert kernels that require unquantized inputs,
# defer quantization to FusedMoEExpertsPermuteUnpermute.
if not quant_config.is_block_quantized and not defer_input_quant:
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
# TODO: support per_act_token_quant,
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
......@@ -261,7 +257,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.ReceiverType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
......@@ -271,12 +266,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
a1 = a1 * topk_weights.to(a1.dtype)
# * DeepEP only supports fp8 block scales so quantize
# before the dispatch for these models.
# * For all other quantization, dispatch after.
# * For expert kernels that require unquantized inputs,
# defer quantization to FusedMoEExpertsPermuteUnpermute.
if quant_config.is_block_quantized and not defer_input_quant:
if quant_config.is_block_quantized:
# Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_scale,
......@@ -290,11 +281,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else:
a1q = a1
a1q_scale = None
a1_post_scale = (
quant_config.a1_gscale
if quant_config.quant_dtype == "nvfp4"
else quant_config.a1_scale
)
a1_post_scale = quant_config.a1_scale
return self._do_dispatch(
tokens=a1q,
......@@ -304,7 +291,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config,
defer_input_quant=defer_input_quant,
)
def prepare(
......@@ -316,7 +302,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
receiver = self.prepare_async(
a1,
......@@ -326,7 +311,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map,
apply_router_weight_on_input,
quant_config,
defer_input_quant,
)
return receiver()
......@@ -433,4 +417,4 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input,
weight_and_reduce_impl,
False,
)
)
\ No newline at end of file
......@@ -242,14 +242,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
f"Hidden Size {hidden_size} not in supported list of hidden sizes"
......@@ -351,13 +344,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
hook, receiver = self.prepare_async(
a1,
topk_weights,
......@@ -446,4 +433,4 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input,
weight_and_reduce_impl,
do_async=False,
)
)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.base_device_communicator import (
All2AllManagerBase,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""Base class for FlashInfer MoE prepare and finalize operations."""
def __init__(
self,
num_dispatchers: int = 1,
):
super().__init__()
self.num_dispatchers_ = num_dispatchers
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return False
def _apply_router_weight_on_input(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""Apply router weight on input if needed."""
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1.mul_(topk_weights.to(a1.dtype))
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input
)
global_num_tokens_cpu = get_local_sizes()
top_k = topk_ids.size(1)
(self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = (
flashinfer_alltoall_dispatch(
self.all2all_manager,
global_num_tokens_cpu,
a1,
quant_config.a1_gscale,
topk_ids,
topk_weights,
top_k,
num_experts,
quant_config,
defer_input_quant=defer_input_quant,
)
)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
top_k = topk_ids.size(1)
token_count = output.shape[0]
fused_expert_output = flashinfer_alltoall_combine(
self.all2all_manager,
fused_expert_output,
top_k=top_k,
token_count=token_count,
alltoall_info=self.alltoall_info,
)
output.copy_(fused_expert_output)
def flashinfer_alltoall_dispatch(
all2all_manager: All2AllManagerBase,
global_num_tokens_cpu: list[int],
x: torch.Tensor,
gs: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
num_experts: int,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert all2all_manager.ensure_alltoall_workspace_initialized(), (
"FlashInfer AllToAll workspace not available"
)
ep_rank = all2all_manager.rank
ep_size = all2all_manager.world_size
max_num_token = (
max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
)
orig_topk_weights_dtype = topk_weights.dtype
alltoall_info, topk_ids, topk_weights, _ = (
MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
topk_ids,
topk_weights,
None,
all2all_manager.prepare_workspace_tensor,
max_num_token,
ep_rank,
ep_size,
num_experts,
num_experts,
top_k,
)
)
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
if not defer_input_quant:
x, x_sf = moe_kernel_quantize_input(
x,
gs,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
x_sf,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
# Swizzle after the A2A if nvfp4.
if quant_config.quant_dtype == "nvfp4":
if x_sf.element_size() == 1:
x_sf = x_sf.view(torch.uint8)
x_sf = nvfp4_block_scale_interleave(x_sf)
else:
# Block-scale path: pass activations through without quantization
x_sf = None
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
return alltoall_info, topk_ids, topk_weights, x, x_sf
def flashinfer_alltoall_combine(
all2all_manager: All2AllManagerBase,
output: torch.Tensor,
top_k: int,
token_count: int,
alltoall_info,
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert all2all_manager.ensure_alltoall_workspace_initialized(), (
"FlashInfer AllToAll workspace not available"
)
return MnnvlMoe.mnnvl_moe_alltoallv_combine(
output,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank=all2all_manager.rank,
ep_size=all2all_manager.world_size,
top_k=top_k,
token_count=token_count,
)
......@@ -78,9 +78,16 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
@property
def expects_unquantized_inputs(self) -> bool:
return self.quant_config.use_fp8_w8a8 and self.quant_config.is_block_quantized
@staticmethod
def expects_unquantized_inputs(
moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# NVFP4 TP kernels and FP8 block-quantized kernels apply
# input quantization inside FusedMoEPermuteExpertsUnpermute.
return (
quant_config.use_nvfp4_w4a4
and not moe_config.moe_parallel_config.use_all2all_kernels
) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
@staticmethod
def _supports_current_device() -> bool:
......@@ -138,8 +145,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FLASHINFER_CUTLASS currently uses its down P/F, which does not
# work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function.
# TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As.
return not moe_parallel_config.is_sequence_parallel
return (
moe_parallel_config.dp_size == 1
or moe_parallel_config.dp_size == moe_parallel_config.ep_size
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
......@@ -186,9 +195,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
workspace1 = (M, K)
workspace2 = (0,)
# For NVFP4, the output is stored in a packed int8 format,
# so the actual hidden dim is 2x the size of K here.
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
# For TP, the quantization is fused with fused_moe call.
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape)
......@@ -292,4 +300,4 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
# No support for LoRA in flashinfer_cutlass_fused_moe.
# See TODOs in flashinfer functions runMoe and runMoeMinLantency.
raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe")
raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe")
\ No newline at end of file
......@@ -533,13 +533,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
......@@ -1124,4 +1118,4 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
config=config,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape,
)
)
\ No newline at end of file
......@@ -597,7 +597,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not moe_parallel_config.use_fi_all2allv_kernels
return True
@property
def quant_type_id(self) -> int:
......
......@@ -1108,6 +1108,7 @@ def dispatch_fused_moe_kernel(
num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8,
)
if use_moe_wna16_cuda:
invoke_fused_moe_wna16_cuda_kernel(
A,
......@@ -1167,6 +1168,7 @@ def dispatch_fused_moe_kernel(
B_bias,
)
@triton.jit
def compute_identity_kernel(
top_k: int,
......@@ -2266,7 +2268,6 @@ def fused_experts_impl(
return out_hidden_states
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
......
......@@ -5,7 +5,6 @@ from abc import abstractmethod
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
......@@ -27,19 +26,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super().__init__()
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
self.moe_mk: mk.FusedMoEModularKernel | None = None
@property
def supports_internal_mk(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return self.moe_mk is not None
@property
def mk_owns_shared_expert(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return self.moe_mk is not None and self.moe_mk.shared_experts is not None
@abstractmethod
def create_weights(
......@@ -105,8 +91,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.moe_mk is not None:
return self.moe_mk.prepare_finalize.topk_indices_dtype()
return None
@property
......@@ -142,4 +126,4 @@ class FusedMoEMethodBase(QuantizeMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
raise NotImplementedError
\ No newline at end of file
......@@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
):
super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config
self.moe_mk = experts
self.fused_experts = experts
self.disable_expert_map = getattr(
old_quant_method,
"disable_expert_map",
not self.moe_mk.supports_expert_map(),
not self.fused_experts.supports_expert_map(),
)
self.old_quant_method = old_quant_method
assert not self.old_quant_method.is_monolithic
......@@ -57,6 +57,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
),
)
@property
def topk_indices_dtype(self) -> torch.dtype | None:
return self.fused_experts.prepare_finalize.topk_indices_dtype()
@property
def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb
......@@ -92,8 +96,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -104,4 +107,4 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
)
)
\ No newline at end of file
......@@ -757,6 +757,14 @@ class FusedMoE(CustomOp):
def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return (
self.moe_quant_config is not None
and self.moe_quant_config.quant_dtype == "nvfp4"
and self.moe_config_use_flashinfer_cutlass_kernels
)
@property
def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False)
......@@ -767,7 +775,7 @@ class FusedMoE(CustomOp):
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels
or self.moe_parallel_config.use_fi_all2allv_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property
......@@ -1189,8 +1197,6 @@ class FusedMoE(CustomOp):
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
......@@ -1567,7 +1573,7 @@ class FusedMoE(CustomOp):
assert self.quant_method is not None
return (
isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
and self.quant_method.fused_experts.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
......@@ -1706,36 +1712,6 @@ class FusedMoE(CustomOp):
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
zero_expert_result = None
if self.zero_expert_num > 0 and self.zero_expert_type is not None:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=staged_hidden_states,
router_logits=staged_router_logits,
use_grouped_topk=self.use_grouped_topk,
top_k=self.top_k,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
indices_type=self.quant_method.topk_indices_dtype,
enable_eplb=self.enable_eplb,
expert_map=self.expert_map,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count)
# Compute zero_expert_result
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=self.global_num_experts,
zero_expert_type=self.zero_expert_type,
hidden_states=staged_hidden_states,
)
# Matrix multiply.
if self.quant_method.is_monolithic:
......@@ -1831,7 +1807,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init()
has_separate_shared_experts = (
not self.quant_method.mk_owns_shared_expert
not isinstance(self.quant_method, FusedMoEModularMethod)
and self.shared_experts is not None
)
......@@ -1857,8 +1833,8 @@ class FusedMoE(CustomOp):
# NOTE(rob): once we finish migrating all the quant methods to use
# MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine = (
self.dp_size > 1 and not self.quant_method.supports_internal_mk
do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
self.quant_method, FusedMoEModularMethod
)
ctx = get_forward_context()
......@@ -1886,7 +1862,7 @@ class FusedMoE(CustomOp):
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch_router_logits(
dispatch_res = get_ep_group().dispatch(
hidden_states_to_dispatch,
router_logits,
self.is_sequence_parallel,
......@@ -1948,7 +1924,6 @@ class FusedMoE(CustomOp):
if self.capture is not None:
self.capture(topk_ids)
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
......@@ -1988,7 +1963,6 @@ class FusedMoE(CustomOp):
dim=0,
)
return states
if self.shared_experts is not None:
return (
final_hidden_states[0],
......@@ -2043,7 +2017,6 @@ class FusedMoE(CustomOp):
]
]
def extra_repr(self) -> str:
s = (
f"global_num_experts={self.global_num_experts}, "
......@@ -2090,8 +2063,6 @@ def moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......
......@@ -180,7 +180,6 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> PrepareResultType:
"""
Perform any quantization (and/or) dispatching needed for this kernel.
......@@ -193,9 +192,6 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a tuple of:
- quantized + dispatched a.
......@@ -224,7 +220,6 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> tuple[Callable, ReceiverType] | ReceiverType:
"""
Perform any quantization (and/or) dispatching needed for this kernel
......@@ -240,9 +235,6 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a callback or a hook callback pair that when invoked waits for
results from other workers and has the same return signature as
......@@ -415,8 +407,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property
def expects_unquantized_inputs(self) -> bool:
@staticmethod
def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
"""
Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will
......@@ -1075,7 +1069,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
else:
# Overlap shared expert compute with all2all dispatch.
......@@ -1088,7 +1081,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
# TODO(lucas): refactor this in the alternative schedules followup
......@@ -1139,7 +1131,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids
......@@ -1215,7 +1206,6 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
)
return fused_out
......@@ -1299,7 +1289,6 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
......@@ -1361,7 +1350,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
)
return self._finalize(
......@@ -1371,4 +1359,4 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
apply_router_weight_on_input,
)
)
\ No newline at end of file
......@@ -3,6 +3,70 @@
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: torch.Tensor | None,
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: torch.Tensor | None,
block_m: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
)
inv_perm: torch.Tensor | None = None
num_tokens = top_k_num * tokens_in_chunk
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: torch.Tensor | None,
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.size()
K = curr_hidden.size(-1)
if inv_perm is not None:
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
if not apply_router_weight_on_input:
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def moe_permute(
hidden_states: torch.Tensor,
......@@ -162,4 +226,4 @@ def moe_unpermute(
def moe_permute_unpermute_supported():
return torch.ops._moe_C.moe_permute_unpermute_supported()
return torch.ops._moe_C.moe_permute_unpermute_supported()
\ No newline at end of file
......@@ -58,7 +58,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
"""
Returns a tuple of:
......@@ -70,11 +69,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
"""
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now."
)
......@@ -124,4 +118,4 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
None,
topk_ids,
)[0]
output.copy_(result[:num_token])
output.copy_(result[:num_token])
\ No newline at end of file
......@@ -8,9 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -20,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_fp8,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
......@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -331,16 +330,9 @@ def select_fp8_moe_backend(
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
# TODO(rob): per discussion with TPU team, we need a way to register
# MoE backends by OOT plugins, rather than having an explicit list
# of AVAILBLE_BACKENDS. Enabling returning `Fp8MoeBackend.NONE` is
# a temporary measure until these register APIs are complete.
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
"No FP8 MoE backend supports the deployment configuration."
)
return Fp8MoeBackend.NONE, None
raise NotImplementedError(
"No FP8 MoE backend supports the deployment configuration."
)
def convert_to_fp8_moe_kernel_format(
......@@ -465,55 +457,71 @@ def make_fp8_moe_quant_config(
)
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
def make_fp8_moe_kernel_for_mkm(
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
max_num_tokens=max_num_tokens,
quant_config=quant_config,
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
quant_config=quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"Fp8 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
shared_experts=None,
moe_parallel_config=moe_config.moe_parallel_config,
)
# TODO(rob): update inplace logic to be part of the kernel.
inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
return kernel, inplace
return kernel, inplace
\ No newline at end of file
......@@ -7,9 +7,6 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -17,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
......@@ -391,53 +391,69 @@ def make_nvfp4_moe_quant_config(
)
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
def make_nvfp4_moe_kernel_for_mkm(
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
max_num_tokens=max_num_tokens,
quant_config=quant_config,
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
quant_config=quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
shared_experts=None,
moe_parallel_config=moe_config.moe_parallel_config,
)
# TODO(rob): update inplace logic to be part of the kernel.
return kernel
return kernel
\ No newline at end of file
......@@ -106,14 +106,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
......@@ -281,7 +274,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
......@@ -291,7 +283,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map,
apply_router_weight_on_input,
quant_config,
defer_input_quant=defer_input_quant,
)
hook()
return receiver()
......@@ -368,4 +359,4 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input,
weight_and_reduce_impl,
)
receiver()
receiver()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment