Unverified Commit 5a93b916 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarAmir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avataramirkl94 <203507526+amirkl94@users.noreply.github.com>
parent 6d86fde0
...@@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True # NOTE(rob): discovered an IMA with this combination. Needs investigation.
return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
......
...@@ -103,6 +103,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -103,6 +103,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int, num_experts: int,
a1_scale: torch.Tensor | None, a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> Callable: ) -> Callable:
has_scales = token_scales is not None has_scales = token_scales is not None
...@@ -174,6 +175,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -174,6 +175,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights, expert_topk_weights,
a1_scale, a1_scale,
quant_config, quant_config,
defer_input_quant=defer_input_quant,
) )
def _receiver( def _receiver(
...@@ -187,6 +189,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -187,6 +189,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights: torch.Tensor | None, expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None, a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if event.event is not None: if event.event is not None:
event.current_stream_wait() event.current_stream_wait()
...@@ -221,14 +224,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -221,14 +224,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_num_tokens_per_expert_list, device=expert_x.device expert_num_tokens_per_expert_list, device=expert_x.device
) )
# Dispatch and Quant # * For non-block quant, dispatch in b16 and quantize now as
# DeepEP kernels only support dispatching block-quantized # DeepEP kernels only support dispatching block scales.
# activation scales. # * For expert kernels that require unquantized inputs,
# Dispatch in bfloat16 and quantize afterwards # defer quantization to FusedMoEExpertsPermuteUnpermute.
if not quant_config.is_block_quantized: if not quant_config.is_block_quantized and not defer_input_quant:
# Quantize after dispatch. # Quantize after dispatch.
expert_x_scale = None expert_x_scale = None
if expert_x.numel() != 0: if expert_x.numel() != 0:
# TODO: support per_act_token_quant,
expert_x, expert_x_scale = moe_kernel_quantize_input( expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x, expert_x,
a1_scale, a1_scale,
...@@ -257,6 +261,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -257,6 +261,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.ReceiverType: ) -> mk.ReceiverType:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -266,8 +271,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -266,8 +271,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
if quant_config.is_block_quantized: # * DeepEP only supports fp8 block scales so quantize
# Quant and Dispatch # before the dispatch for these models.
# * For all other quantization, dispatch after.
# * For expert kernels that require unquantized inputs,
# defer quantization to FusedMoEExpertsPermuteUnpermute.
if quant_config.is_block_quantized and not defer_input_quant:
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
quant_config.a1_scale, quant_config.a1_scale,
...@@ -281,7 +290,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -281,7 +290,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else: else:
a1q = a1 a1q = a1
a1q_scale = None a1q_scale = None
a1_post_scale = quant_config.a1_scale a1_post_scale = (
quant_config.a1_gscale
if quant_config.quant_dtype == "nvfp4"
else quant_config.a1_scale
)
return self._do_dispatch( return self._do_dispatch(
tokens=a1q, tokens=a1q,
...@@ -291,6 +304,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -291,6 +304,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts=num_experts, num_experts=num_experts,
a1_scale=a1_post_scale, a1_scale=a1_post_scale,
quant_config=quant_config, quant_config=quant_config,
defer_input_quant=defer_input_quant,
) )
def prepare( def prepare(
...@@ -302,6 +316,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -302,6 +316,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
receiver = self.prepare_async( receiver = self.prepare_async(
a1, a1,
...@@ -311,6 +326,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -311,6 +326,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config, quant_config,
defer_input_quant,
) )
return receiver() return receiver()
......
...@@ -242,7 +242,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -242,7 +242,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]: ) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
hidden_size = a1.size(1) hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, ( assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
f"Hidden Size {hidden_size} not in supported list of hidden sizes" f"Hidden Size {hidden_size} not in supported list of hidden sizes"
...@@ -344,7 +351,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -344,7 +351,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
topk_weights, topk_weights,
......
...@@ -4,18 +4,12 @@ ...@@ -4,18 +4,12 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_dp_group, get_ep_group from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.base_device_communicator import ( from vllm.distributed.device_communicators.base_device_communicator import (
All2AllManagerBase, All2AllManagerBase,
) )
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave from vllm.utils.flashinfer import nvfp4_block_scale_interleave
...@@ -24,22 +18,16 @@ def get_local_sizes(): ...@@ -24,22 +18,16 @@ def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""Base class for FlashInfer MoE prepare and finalize operations.""" """Base class for FlashInfer MoE prepare and finalize operations."""
def __init__( def __init__(
self, self,
use_dp: bool,
num_dispatchers: int = 1, num_dispatchers: int = 1,
use_deepseek_fp8_block_scale: bool = False,
): ):
super().__init__() super().__init__()
self.num_dispatchers_ = num_dispatchers self.num_dispatchers_ = num_dispatchers
self.use_dp = use_dp self.all2all_manager = get_ep_group().device_communicator.all2all_manager
self.local_tokens = None
# Toggle for DeepSeek-style FP8 block-scale path where activations are
# not quantized here and weight block scales are consumed by the kernel.
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
...@@ -72,24 +60,6 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -72,24 +60,6 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
a1.mul_(topk_weights.to(a1.dtype)) a1.mul_(topk_weights.to(a1.dtype))
class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
"""FlashInfer implementation using AllToAll communication."""
def __init__(
self,
use_dp: bool,
num_dispatchers: int = 1,
use_deepseek_fp8_block_scale: bool = False,
):
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
self.alltoall_info = None
# Initialize all2all_manager only for DP case
self.all2all_manager = None
if self.use_dp:
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
...@@ -99,134 +69,28 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina ...@@ -99,134 +69,28 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
self._apply_router_weight_on_input( self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input a1, topk_weights, topk_ids, apply_router_weight_on_input
) )
global_num_tokens_cpu = get_local_sizes()
top_k = topk_ids.size(1)
if not self.use_dp: (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = (
# Non-DP case: quantize activations unless using block-scale path flashinfer_alltoall_dispatch(
if not self.use_deepseek_fp8_block_scale:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
else:
a1q = a1
a1q_scale = None
else:
# DP case: use FlashInfer AllToAll
global_num_tokens_cpu = get_local_sizes()
top_k = topk_ids.size(1)
(self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = (
flashinfer_alltoall_dispatch(
self.all2all_manager,
global_num_tokens_cpu,
a1,
quant_config.a1_gscale,
topk_ids,
topk_weights,
top_k,
num_experts,
quant_config,
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
)
)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if self.use_dp:
top_k = topk_ids.size(1)
token_count = output.shape[0]
fused_expert_output = flashinfer_alltoall_combine(
self.all2all_manager, self.all2all_manager,
fused_expert_output, global_num_tokens_cpu,
top_k=top_k,
token_count=token_count,
alltoall_info=self.alltoall_info,
)
output.copy_(fused_expert_output)
class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
def __init__(
self,
use_dp: bool,
num_dispatchers: int = 1,
use_deepseek_fp8_block_scale: bool = False,
):
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input
)
is_nvfp4 = quant_config.quant_dtype == "nvfp4"
if not self.use_dp and is_nvfp4:
return a1, None, None, topk_ids, topk_weights
if not self.use_deepseek_fp8_block_scale:
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale, quant_config.a1_gscale,
quant_config.quant_dtype, topk_ids,
quant_config.per_act_token_quant, topk_weights,
quant_config.block_shape, top_k,
is_fp4_scale_swizzled=not self.use_dp, num_experts,
quant_config,
defer_input_quant=defer_input_quant,
) )
else: )
# Block-scale path: pass activations through, omit per-token scales
a1q = a1
a1q_scale = None
if self.use_dp:
# Build gather list conditionally - omit a1q_scale if None
# (block-scale path)
gather_list = [topk_weights, topk_ids, a1q]
if a1q_scale is not None:
gather_list.append(a1q_scale)
gathered = get_dp_group().all_gatherv(
gather_list,
dim=0,
sizes=get_local_sizes(),
)
topk_weights, topk_ids, a1q, a1q_scale = gathered
else:
gathered = get_dp_group().all_gatherv(
gather_list,
dim=0,
sizes=get_local_sizes(),
)
topk_weights, topk_ids, a1q = gathered
a1q_scale = None
if is_nvfp4 and a1q_scale is not None:
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights
...@@ -239,12 +103,15 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin ...@@ -239,12 +103,15 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: ) -> None:
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP) top_k = topk_ids.size(1)
token_count = output.shape[0]
if self.use_dp: fused_expert_output = flashinfer_alltoall_combine(
fused_expert_output = get_dp_group().reduce_scatterv( self.all2all_manager,
fused_expert_output, dim=0, sizes=get_local_sizes() fused_expert_output,
) top_k=top_k,
token_count=token_count,
alltoall_info=self.alltoall_info,
)
output.copy_(fused_expert_output) output.copy_(fused_expert_output)
...@@ -258,7 +125,7 @@ def flashinfer_alltoall_dispatch( ...@@ -258,7 +125,7 @@ def flashinfer_alltoall_dispatch(
top_k: int, top_k: int,
num_experts: int, num_experts: int,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
use_deepseek_fp8_block_scale: bool = False, defer_input_quant: bool = False,
): ):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe from flashinfer.comm.trtllm_alltoall import MnnvlMoe
...@@ -288,15 +155,20 @@ def flashinfer_alltoall_dispatch( ...@@ -288,15 +155,20 @@ def flashinfer_alltoall_dispatch(
) )
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype) topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
if not use_deepseek_fp8_block_scale: if not defer_input_quant:
x, x_sf = moe_kernel_quantize_input( x, x_sf = moe_kernel_quantize_input(
x, x,
gs, gs,
quant_config.quant_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.per_act_token_quant,
quant_config.block_shape, quant_config.block_shape,
is_fp4_scale_swizzled=False, # delay swizzle to after comm # NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
) )
x = MnnvlMoe.mnnvl_moe_alltoallv( x = MnnvlMoe.mnnvl_moe_alltoallv(
x, x,
alltoall_info, alltoall_info,
...@@ -312,7 +184,11 @@ def flashinfer_alltoall_dispatch( ...@@ -312,7 +184,11 @@ def flashinfer_alltoall_dispatch(
ep_rank, ep_rank,
ep_size, ep_size,
) )
# Swizzle after the A2A if nvfp4.
if quant_config.quant_dtype == "nvfp4": if quant_config.quant_dtype == "nvfp4":
if x_sf.element_size() == 1:
x_sf = x_sf.view(torch.uint8)
x_sf = nvfp4_block_scale_interleave(x_sf) x_sf = nvfp4_block_scale_interleave(x_sf)
else: else:
# Block-scale path: pass activations through without quantization # Block-scale path: pass activations through without quantization
...@@ -348,26 +224,3 @@ def flashinfer_alltoall_combine( ...@@ -348,26 +224,3 @@ def flashinfer_alltoall_combine(
top_k=top_k, top_k=top_k,
token_count=token_count, token_count=token_count,
) )
def create_flashinfer_prepare_finalize(
use_dp: bool,
use_nvfp4: bool = False,
enable_alltoallv: bool = False,
use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
"""Factory function to create the appropriate FlashInfer implementation."""
if use_dp:
if enable_alltoallv:
assert use_nvfp4
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
return FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=True,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
else:
# CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
# in a single call with the MoE experts kernel.
defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4
return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant)
...@@ -78,16 +78,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -78,16 +78,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# - skip input activation quantization (kernel applies scaling) # - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
@staticmethod @property
def expects_unquantized_inputs( def expects_unquantized_inputs(self) -> bool:
moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig return self.quant_config.use_fp8_w8a8 and self.quant_config.is_block_quantized
) -> bool:
# NVFP4 TP kernels and FP8 block-quantized kernels apply
# input quantization inside FusedMoEPermuteExpertsUnpermute.
return (
quant_config.use_nvfp4_w4a4
and not moe_config.moe_parallel_config.use_all2all_kernels
) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
...@@ -144,10 +137,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -144,10 +137,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FLASHINFER_CUTLASS currently uses its down P/F, which does not # FLASHINFER_CUTLASS currently uses its down P/F, which does not
# work with SP. This will be removed in follow up after we get # work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function. # rid of the FlashInfer specific P/F function.
return ( # TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As.
moe_parallel_config.dp_size == 1 return not moe_parallel_config.is_sequence_parallel
or moe_parallel_config.dp_size == moe_parallel_config.ep_size
)
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
...@@ -194,8 +185,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -194,8 +185,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
workspace1 = (M, K) workspace1 = (M, K)
workspace2 = (0,) workspace2 = (0,)
# For TP, the quantization is fused with fused_moe call. # For NVFP4, the output is stored in a packed int8 format,
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K) # so the actual hidden dim is 2x the size of K here.
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
# The workspace is determined by `aq`, since it comes after any # The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation. # potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape) return (workspace1, workspace2, output_shape)
......
...@@ -533,7 +533,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -533,7 +533,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert a1.dim() == 2 assert a1.dim() == 2
assert topk_ids.dim() == 2 assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0) assert topk_ids.size(0) == a1.size(0)
......
...@@ -593,7 +593,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -593,7 +593,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True return not moe_parallel_config.use_fi_all2allv_kernels
@property @property
def quant_type_id(self) -> int: def quant_type_id(self) -> int:
......
...@@ -1951,7 +1951,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1951,7 +1951,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
......
...@@ -5,6 +5,7 @@ from abc import abstractmethod ...@@ -5,6 +5,7 @@ from abc import abstractmethod
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
...@@ -26,6 +27,19 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -26,6 +27,19 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super().__init__() super().__init__()
self.moe: FusedMoEConfig = moe self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None self.moe_quant_config: FusedMoEQuantConfig | None = None
self.moe_mk: mk.FusedMoEModularKernel | None = None
@property
def supports_internal_mk(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return self.moe_mk is not None
@property
def mk_owns_shared_expert(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return self.moe_mk is not None and self.moe_mk.shared_experts is not None
@abstractmethod @abstractmethod
def create_weights( def create_weights(
...@@ -91,6 +105,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -91,6 +105,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property @property
def topk_indices_dtype(self) -> torch.dtype | None: def topk_indices_dtype(self) -> torch.dtype | None:
if self.moe_mk is not None:
return self.moe_mk.prepare_finalize.topk_indices_dtype()
return None return None
@property @property
......
...@@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
): ):
super().__init__(old_quant_method.moe) super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config self.moe_quant_config = old_quant_method.moe_quant_config
self.fused_experts = experts self.moe_mk = experts
self.disable_expert_map = getattr( self.disable_expert_map = getattr(
old_quant_method, old_quant_method,
"disable_expert_map", "disable_expert_map",
not self.fused_experts.supports_expert_map(), not self.moe_mk.supports_expert_map(),
) )
self.old_quant_method = old_quant_method self.old_quant_method = old_quant_method
assert not self.old_quant_method.is_monolithic assert not self.old_quant_method.is_monolithic
...@@ -57,10 +57,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -57,10 +57,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
), ),
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
return self.fused_experts.prepare_finalize.topk_indices_dtype()
@property @property
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb return self.old_quant_method.supports_eplb
...@@ -96,7 +92,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -96,7 +92,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.fused_experts( assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
......
...@@ -571,9 +571,6 @@ class FusedMoE(CustomOp): ...@@ -571,9 +571,6 @@ class FusedMoE(CustomOp):
device=vllm_config.device_config.device, device=vllm_config.device_config.device,
routing_method=self.routing_method_type, routing_method=self.routing_method_type,
) )
self.moe_config_use_flashinfer_cutlass_kernels = (
self.moe_config.use_flashinfer_cutlass_kernels
)
if self.use_mori_kernels: if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, ( assert self.rocm_aiter_fmoe_enabled, (
"Mori needs to be used with aiter fused_moe for now." "Mori needs to be used with aiter fused_moe for now."
...@@ -646,6 +643,11 @@ class FusedMoE(CustomOp): ...@@ -646,6 +643,11 @@ class FusedMoE(CustomOp):
# This is called after all weight loading and post-processing, so it # This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method. # should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None: def maybe_init_modular_kernel(self) -> None:
# NOTE(rob): WIP refactor. For quant methods that own the MK
# we create the MK during process_weights_after_loading.
if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic:
return None
self.ensure_moe_quant_config_init() self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with # routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend. # DeepEP all2all backend.
...@@ -728,14 +730,6 @@ class FusedMoE(CustomOp): ...@@ -728,14 +730,6 @@ class FusedMoE(CustomOp):
def use_mori_kernels(self): def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels return self.moe_parallel_config.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return (
self.moe_quant_config is not None
and self.moe_quant_config.quant_dtype == "nvfp4"
and self.moe_config_use_flashinfer_cutlass_kernels
)
@property @property
def use_marlin_kernels(self): def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False) return getattr(self.quant_method, "use_marlin", False)
...@@ -746,7 +740,7 @@ class FusedMoE(CustomOp): ...@@ -746,7 +740,7 @@ class FusedMoE(CustomOp):
self.moe_parallel_config.use_pplx_kernels self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels or self.moe_parallel_config.use_mori_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) or self.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property @property
...@@ -1532,7 +1526,7 @@ class FusedMoE(CustomOp): ...@@ -1532,7 +1526,7 @@ class FusedMoE(CustomOp):
assert self.quant_method is not None assert self.quant_method is not None
return ( return (
isinstance(self.quant_method, FusedMoEModularMethod) isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.fused_experts.output_is_reduced() and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
) )
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
...@@ -1765,7 +1759,7 @@ class FusedMoE(CustomOp): ...@@ -1765,7 +1759,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init() self.ensure_dp_chunking_init()
has_separate_shared_experts = ( has_separate_shared_experts = (
not isinstance(self.quant_method, FusedMoEModularMethod) not self.quant_method.mk_owns_shared_expert
and self.shared_experts is not None and self.shared_experts is not None
) )
...@@ -1789,8 +1783,10 @@ class FusedMoE(CustomOp): ...@@ -1789,8 +1783,10 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, has_separate_shared_experts hidden_states, router_logits, has_separate_shared_experts
) )
do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( # NOTE(rob): once we finish migrating all the quant methods to use
self.quant_method, FusedMoEModularMethod # MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine = (
self.dp_size > 1 and not self.quant_method.supports_internal_mk
) )
ctx = get_forward_context() ctx = get_forward_context()
...@@ -1818,7 +1814,7 @@ class FusedMoE(CustomOp): ...@@ -1818,7 +1814,7 @@ class FusedMoE(CustomOp):
else: else:
hidden_states_to_dispatch = hidden_states hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch( dispatch_res = get_ep_group().dispatch_router_logits(
hidden_states_to_dispatch, hidden_states_to_dispatch,
router_logits, router_logits,
self.is_sequence_parallel, self.is_sequence_parallel,
......
...@@ -180,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -180,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> PrepareResultType: ) -> PrepareResultType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel. Perform any quantization (and/or) dispatching needed for this kernel.
...@@ -192,6 +193,9 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -192,6 +193,9 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts. - quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a tuple of: Returns a tuple of:
- quantized + dispatched a. - quantized + dispatched a.
...@@ -220,6 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -220,6 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> tuple[Callable, ReceiverType] | ReceiverType: ) -> tuple[Callable, ReceiverType] | ReceiverType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel Perform any quantization (and/or) dispatching needed for this kernel
...@@ -235,6 +240,9 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -235,6 +240,9 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard. space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a callback or a hook callback pair that when invoked waits for Returns a callback or a hook callback pair that when invoked waits for
results from other workers and has the same return signature as results from other workers and has the same return signature as
...@@ -407,10 +415,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -407,10 +415,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
@staticmethod @property
def expects_unquantized_inputs( def expects_unquantized_inputs(self) -> bool:
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
""" """
Whether or not the PrepareFinalize should defer input quantization Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will in the prepare step. If True, then the Experts kernel will
...@@ -1069,6 +1075,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1069,6 +1075,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
) )
else: else:
# Overlap shared expert compute with all2all dispatch. # Overlap shared expert compute with all2all dispatch.
...@@ -1081,6 +1088,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1081,6 +1088,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
) )
# TODO(lucas): refactor this in the alternative schedules followup # TODO(lucas): refactor this in the alternative schedules followup
......
...@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
""" """
Returns a tuple of: Returns a tuple of:
...@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- Optional dispatched expert topk IDs - Optional dispatched expert topk IDs
- Optional dispatched expert topk weight - Optional dispatched expert topk weight
""" """
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now." "mori does not support apply_router_weight_on_input=True now."
) )
......
...@@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -17,9 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -17,9 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm, is_supported_config_trtllm,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
...@@ -465,68 +465,52 @@ def make_fp8_moe_quant_config( ...@@ -465,68 +465,52 @@ def make_fp8_moe_quant_config(
) )
def make_fp8_moe_kernel_for_mkm( def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
prepare_finalize: mk.FusedMoEPrepareAndFinalize, fp8_backend: Fp8MoeBackend,
) -> mk.FusedMoEPermuteExpertsUnpermute: routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None assert max_num_tokens is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"Fp8 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in # if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=None, shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
) )
......
...@@ -7,6 +7,9 @@ import torch ...@@ -7,6 +7,9 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -14,9 +17,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -14,9 +17,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm, is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
...@@ -391,67 +391,51 @@ def make_nvfp4_moe_quant_config( ...@@ -391,67 +391,51 @@ def make_nvfp4_moe_quant_config(
) )
def make_nvfp4_moe_kernel_for_mkm( def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
prepare_finalize: mk.FusedMoEPrepareAndFinalize, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPermuteExpertsUnpermute: shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None assert max_num_tokens is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in # if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=None, shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
) )
......
...@@ -106,7 +106,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -106,7 +106,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]: ) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
...@@ -274,6 +281,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -274,6 +281,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
...@@ -283,6 +291,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -283,6 +291,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config, quant_config,
defer_input_quant=defer_input_quant,
) )
hook() hook()
return receiver() return receiver()
......
...@@ -4,18 +4,25 @@ ...@@ -4,18 +4,25 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
) )
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
def __init__(self, defer_input_quant: bool = False) -> None: def __init__(
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__() super().__init__()
self.defer_input_quant = defer_input_quant self.is_sequence_parallel = is_sequence_parallel
self._num_dispatchers = num_dispatchers
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
...@@ -27,6 +34,113 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -27,6 +34,113 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def topk_indices_dtype(self) -> torch.dtype | None: def topk_indices_dtype(self) -> torch.dtype | None:
return None return None
def num_dispatchers(self) -> int:
return self._num_dispatchers
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quantization to the MoE kernel.
use_nvfp4 = quant_config.use_nvfp4_w4a4
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
res = get_ep_group().dispatch(
a1q,
topk_weights,
topk_ids,
is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales,
)
if skip_gather_scales:
a1q, topk_weights, topk_ids = res
else:
a1q, topk_weights, topk_ids, scales = res
assert scales is not None and len(scales) == 1
a1q_scale = scales[0]
if quant_config.quant_dtype == "nvfp4":
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
out = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
output.copy_(
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return 1 return 1
...@@ -42,6 +156,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -42,6 +156,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -54,12 +169,17 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -54,12 +169,17 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
# Defer input quant to moe kernel for backends (e.g. AITER, FI) # Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts. # which use a single kernel call for quant + experts.
if self.defer_input_quant: if defer_input_quant:
return a1, None, None, None, None return a1, None, None, None, None
input_sf = (
quant_config.a1_gscale
if quant_config.use_nvfp4_w4a4
else quant_config.a1_scale
)
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
quant_config.a1_scale, input_sf,
quant_config.quant_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.per_act_token_quant,
quant_config.block_shape, quant_config.block_shape,
......
...@@ -287,17 +287,14 @@ def rocm_aiter_fused_experts( ...@@ -287,17 +287,14 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
@staticmethod
def expects_unquantized_inputs(
fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# AITER fused MoE kernels handle input quantization internally.
return True
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
return rocm_aiter_ops.is_fused_moe_enabled() return rocm_aiter_ops.is_fused_moe_enabled()
...@@ -329,7 +326,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -329,7 +326,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True return not moe_parallel_config.use_fi_all2allv_kernels
def supports_expert_map(self): def supports_expert_map(self):
return True return True
......
...@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE): ...@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE):
use_overlapped use_overlapped
and not ( and not (
(self.enable_eplb and backend != "allgather_reducescatter") (self.enable_eplb and backend != "allgather_reducescatter")
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) or self.moe_parallel_config.use_fi_all2allv_kernels
) )
and self._shared_experts is not None and self._shared_experts is not None
) )
......
...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -53,7 +52,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -53,7 +52,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -67,7 +65,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( ...@@ -67,7 +65,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
...@@ -243,7 +240,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -243,7 +240,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.group_size = 32 self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights( def create_weights(
self, self,
...@@ -320,7 +316,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -320,7 +316,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False layer.w13_weight_packed.data, requires_grad=False
) )
...@@ -335,10 +331,12 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -335,10 +331,12 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None: if self.moe_quant_config is not None:
self.kernel = make_nvfp4_moe_kernel( self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def apply( def apply(
...@@ -348,8 +346,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -348,8 +346,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -380,19 +378,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -380,19 +378,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
activation_key=None if use_a16 else kNvfp4Dynamic, activation_key=None if use_a16 else kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -506,7 +495,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -506,7 +495,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
set_weight_attrs(w2_input_scale, extra_weight_attrs) set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -572,48 +561,33 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -572,48 +561,33 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel( self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: raise ValueError(
return None f"{self.__class__.__name__} uses the new modular kernel initialization "
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: "logic. This function should not be called."
# For no-EP case, don't use the MKM framework. )
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_nvfp4_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -684,8 +658,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -684,8 +658,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -759,15 +733,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -759,15 +733,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
allow_vllm_cutlass=True, allow_vllm_cutlass=True,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -927,7 +892,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -927,7 +892,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
# Allow for accessing weights and scales in standard way. # Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight w13 = layer.w13_weight
w2 = layer.w2_weight w2 = layer.w2_weight
...@@ -989,49 +954,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -989,49 +954,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: raise ValueError(
return None f"{self.__class__.__name__} uses the new modular kernel initialization "
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: "logic. This function should not be called."
# For no-EP case, don't use the MKM framework. )
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_fp8_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -1120,8 +1070,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1120,8 +1070,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment