Unverified commit 97995f63, authored by Robert Shaw, committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
...@@ -18,7 +18,7 @@ def get_local_sizes(): ...@@ -18,7 +18,7 @@ def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""Base class for FlashInfer MoE prepare and finalize operations.""" """Base class for FlashInfer MoE prepare and finalize operations."""
def __init__( def __init__(
...@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch( ...@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch(
ep_size, ep_size,
) )
# Swizzle after the A2A if nvfp4. # Swizzle after the A2A if MoE kernel expects swizzled scales.
if quant_config.quant_dtype == "nvfp4": if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
if x_sf.element_size() == 1: if x_sf.element_size() == 1:
x_sf = x_sf.view(torch.uint8) x_sf = x_sf.view(torch.uint8)
x_sf = nvfp4_block_scale_interleave(x_sf) x_sf = nvfp4_block_scale_interleave(x_sf)
......
...@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import ( ...@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import (
logger = init_logger(__name__) logger = init_logger(__name__)
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
......
...@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe( ...@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe(
return True return True
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): class FlashInferExperts(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: mk.FusedMoEConfig, moe_config: mk.FusedMoEConfig,
......
...@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
...@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool: ...@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool:
return True return True
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Return True when the (weight, activation) quant pair is supported.

    Supports Fp8 per-tensor and Fp8 block: the pair must equal one of the
    two combinations enumerated below.
    """
    requested = (weight_key, activation_key)
    supported_pairs = (
        (kFp8Static128BlockSym, kFp8Dynamic128Sym),
        (kFp8StaticTensorSym, kFp8StaticTensorSym),
    )
    return any(requested == pair for pair in supported_pairs)
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
def _supports_routing_method(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    routing_method: RoutingMethodType,
) -> bool:
    """Report whether ``routing_method`` is allowed for the given quant pair.

    Monolithic kernels need to express router support, so the accepted
    routing methods are listed per (weight, activation) quantization pair.

    Raises:
        ValueError: if the pair is neither of the two schemes handled here.
    """
    # NOTE(dbari): TopK routing could also be enabled, but need to validate models
    # NOTE(dbari): Default is not implemented and should not be enabled until it is
    scheme = (weight_key, activation_key)
    if scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
        # NOTE(rob): potentially allow others here. This is a conservative list.
        allowed = [
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        ]
    elif scheme == (kFp8StaticTensorSym, kFp8StaticTensorSym):
        # NOTE(dbari): as above, potentially allow others here.
        allowed = [
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Llama4,
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        ]
    else:
        raise ValueError("Unsupported quantization scheme.")
    return routing_method in allowed
def _supports_routing_method_bf16( def _supports_routing_method_bf16(
routing_method: RoutingMethodType, routing_method: RoutingMethodType,
) -> bool: ) -> bool:
...@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo ...@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
return not moe_parallel_config.enable_eplb return not moe_parallel_config.enable_eplb
def _supports_router_logits_dtype(
    router_logits_dtype: torch.dtype | None,
    routing_method: RoutingMethodType,
) -> bool:
    """Check that the router-logits dtype is usable with the routing method.

    The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
    Only DeepSeekV3 routing supports float32 router_logits (which is converted
    internally in the kernel), see
    https://github.com/flashinfer-ai/flashinfer/issues/2469

    Any dtype other than float32 (including ``None``) is accepted here.
    """
    return (
        router_logits_dtype != torch.float32
        or routing_method == RoutingMethodType.DeepSeekV3
    )
def is_supported_config_trtllm_fp8(
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
    """Decide whether the TRTLLM FP8 kernel can serve this MoE configuration.

    This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config

    Checks run in a fixed order and the first failing one determines the
    reason that is reported.

    Returns:
        ``(True, None)`` when every check passes, otherwise ``(False, reason)``
        where ``reason`` starts with "kernel does not support".
    """

    def _reject(what: str) -> tuple[bool, str | None]:
        return False, f"kernel does not support {what}"

    if not _supports_current_device():
        return _reject(f"current device {current_platform.device_name}")

    if not moe_config.is_act_and_mul and not _supports_no_act_and_mul():
        return _reject("no act_and_mul MLP layer")

    if not _supports_activation(moe_config.activation):
        return _reject(f"{moe_config.activation} activation")

    # Must precede the routing-method check, which raises on unknown schemes.
    if not _supports_quant_scheme(weight_key, activation_key):
        return _reject(f"quantization scheme {weight_key}x{activation_key}")

    if not _supports_parallel_config(moe_config.moe_parallel_config):
        return _reject(f"parallel config {moe_config.moe_parallel_config}")

    routing_ok = _supports_routing_method(
        weight_key, activation_key, moe_config.routing_method
    )
    if not routing_ok:
        return _reject(f"routing method {moe_config.routing_method}")

    if activation_format != mk.FusedMoEActivationFormat.Standard:
        return _reject(f"activation format {activation_format}")

    logits_dtype_ok = _supports_router_logits_dtype(
        moe_config.router_logits_dtype, moe_config.routing_method
    )
    if not logits_dtype_ok:
        return _reject(
            "float32 router_logits with non-DeepSeekV3 routing "
            f"{moe_config.router_logits_dtype}x{moe_config.routing_method}"
        )

    return True, None
def is_supported_config_trtllm_bf16( def is_supported_config_trtllm_bf16(
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
activation_format: mk.FusedMoEActivationFormat, activation_format: mk.FusedMoEActivationFormat,
...@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16( ...@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16(
return True, None return True, None
def flashinfer_fused_moe_blockscale_fp8(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor | None,
    x: torch.Tensor,
    w13_weight: torch.Tensor,
    w13_weight_scale_inv: torch.Tensor,
    w2_weight: torch.Tensor,
    w2_weight_scale_inv: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int | None,
    topk_group: int | None,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: list[int],
    routing_method_type: int,
    routed_scaling: float | None = 1.0,
) -> torch.Tensor:
    """Run the FlashInfer TRTLLM block-scale FP8 fused MoE on ``x``.

    ``x`` is quantized with ``per_token_group_quant_fp8`` (called with
    ``block_shape[1]``), the resulting scales are transposed, and everything
    is forwarded to ``flashinfer_trtllm_fp8_block_scale_moe`` with
    ``use_shuffled_weight=False``.

    ``num_expert_group`` and ``topk_group`` are passed to the kernel as 0
    when given as ``None``. Shape and size preconditions are enforced with
    assertions (``top_k <= 10``, ``global_num_experts`` a multiple of 4 and
    at most 512, ``block_shape == [128, 128]``).
    """
    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

    if num_expert_group is None:
        num_expert_group = 0
    if topk_group is None:
        topk_group = 0

    assert top_k <= global_num_experts
    assert top_k <= 10
    assert global_num_experts % 4 == 0
    assert block_shape == [128, 128]
    # Routing kernel expects #experts <= #threads 512
    assert global_num_experts <= 512

    # The DeepSeekV3 routing method requires float32 router logits.
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        routing_logits = routing_logits.to(torch.float32)

    if routing_bias is not None:
        routing_bias = routing_bias.to(x.dtype)

    quantized_x, x_scales = per_token_group_quant_fp8(x, block_shape[1])
    # NOTE: scales of hidden states have to be transposed!
    x_scales_transposed = x_scales.t().contiguous()

    kernel_output = flashinfer_trtllm_fp8_block_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=quantized_x,
        hidden_states_scale=x_scales_transposed,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        routing_method_type=routing_method_type,
        use_shuffled_weight=False,
    )
    return kernel_output
def flashinfer_fused_moe_blockscale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(x)
# Register flashinfer_fused_moe_blockscale_fp8 as a custom op, with
# flashinfer_fused_moe_blockscale_fp8_fake supplied as its fake implementation.
# The op is tagged torch.Tag.needs_fixed_stride_order.
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
    op_name="flashinfer_fused_moe_blockscale_fp8",
    op_func=flashinfer_fused_moe_blockscale_fp8,
    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
def fi_trtllm_fp8_per_tensor_moe(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor | None,
    hidden_states: torch.Tensor,
    input_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    num_experts: int,
    top_k: int,
    num_expert_group: int | None,
    topk_group: int | None,
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    activation_type: int,
    routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Run the FlashInfer TRTLLM per-tensor-scale FP8 fused MoE.

    ``hidden_states`` is quantized to ``torch.float8_e4m3fn`` through
    ``moe_kernel_quantize_input`` using ``input_scale`` (with
    ``per_act_token_quant=False``), and the result is forwarded to
    ``flashinfer_trtllm_fp8_per_tensor_scale_moe``.

    ``num_expert_group`` and ``topk_group`` are passed to the kernel as 0
    when given as ``None``; ``activation_type`` is converted to FlashInfer's
    ``ActivationType`` enum before the call.
    """
    if num_expert_group is None:
        num_expert_group = 0
    if topk_group is None:
        topk_group = 0

    fp8_hidden_states, _ = moe_kernel_quantize_input(
        hidden_states,
        input_scale,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
    )

    from flashinfer.fused_moe.core import ActivationType

    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe

    # The DeepSeekV3 routing method requires float32 router logits.
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        routing_logits = routing_logits.to(torch.float32)

    kernel_output = flashinfer_trtllm_fp8_per_tensor_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=fp8_hidden_states,
        gemm1_weights=gemm1_weights,
        output1_scales_scalar=output1_scales_scalar,
        output1_scales_gate_scalar=output1_scales_gate_scalar,
        gemm2_weights=gemm2_weights,
        output2_scales_scalar=output2_scales_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        use_routing_scales_on_input=use_routing_scales_on_input,
        routing_method_type=routing_method_type,
        # TODO: enum type Required for flashinfer==0.6.3, remove with update
        # https://github.com/flashinfer-ai/flashinfer/pull/2508
        activation_type=ActivationType(activation_type),
    )
    return kernel_output
def fi_trtllm_fp8_per_tensor_moe_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
# Register fi_trtllm_fp8_per_tensor_moe as a custom op, with
# fi_trtllm_fp8_per_tensor_moe_fake supplied as its fake implementation.
# The op is tagged torch.Tag.needs_fixed_stride_order and declares
# "hidden_states" as a mutated argument.
# NOTE(review): confirm that hidden_states is actually written in place by
# the op; the declaration here is the only evidence of mutation.
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
    op_name="fi_trtllm_fp8_per_tensor_moe",
    op_func=fi_trtllm_fp8_per_tensor_moe,
    mutates_args=["hidden_states"],
    fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
def flashinfer_fused_moe_bf16( def flashinfer_fused_moe_bf16(
routing_logits: torch.Tensor, routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None, routing_bias: torch.Tensor | None,
......
...@@ -489,7 +489,7 @@ def invoke_moe_batched_triton_kernel( ...@@ -489,7 +489,7 @@ def invoke_moe_batched_triton_kernel(
) )
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
""" """
A reference prepare/finalize class that reorganizes the tokens into A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format expert batched format, i.e. E x max_num_tokens x K. This is the format
...@@ -645,7 +645,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -645,7 +645,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
""" """
A reference MoE expert class that operates on expert batched format, A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the batched i.e. E x max_num_tokens x K. This is the format that the batched
...@@ -877,7 +877,7 @@ def batched_moe_kernel_quantize_input( ...@@ -877,7 +877,7 @@ def batched_moe_kernel_quantize_input(
return A_q, A_q_scale return A_q, A_q_scale
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedTritonExperts(mk.FusedMoEExpertsModular):
""" """
A Triton based MoE expert class that operates on expert batched format, A Triton based MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the batched i.e. E x max_num_tokens x K. This is the format that the batched
......
...@@ -526,7 +526,7 @@ def batched_fused_marlin_moe( ...@@ -526,7 +526,7 @@ def batched_fused_marlin_moe(
return output return output
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): class MarlinExpertsBase(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
......
...@@ -1736,7 +1736,7 @@ def fused_experts_impl( ...@@ -1736,7 +1736,7 @@ def fused_experts_impl(
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1 # This needs separate memory since it's used concurrently with cache1
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation( activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation(
N, activation_enum N, activation_enum
) )
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
...@@ -1924,7 +1924,7 @@ def fused_experts_impl( ...@@ -1924,7 +1924,7 @@ def fused_experts_impl(
return out_hidden_states return out_hidden_states
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonExperts(mk.FusedMoEExpertsModular):
"""Triton-based fused MoE expert implementation.""" """Triton-based fused MoE expert implementation."""
def __init__( def __init__(
......
...@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute, FusedMoEExpertsModular,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalizeModular,
) )
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
...@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super().__init__() super().__init__()
self.moe: FusedMoEConfig = moe self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None self.moe_quant_config: FusedMoEQuantConfig | None = None
self.moe_mk: mk.FusedMoEModularKernel | None = None self.moe_kernel: mk.FusedMoEKernel | None = None
@property @property
def supports_internal_mk(self) -> bool: def supports_internal_mk(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for # NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface. # completed migration to the new internal MK interface.
return self.moe_mk is not None return self.moe_kernel is not None
@property @property
def mk_owns_shared_expert(self) -> bool: def mk_owns_shared_expert(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for # NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface. # completed migration to the new internal MK interface.
return self.moe_mk is not None and self.moe_mk.shared_experts is not None return (
self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
)
@abstractmethod @abstractmethod
def create_weights( def create_weights(
...@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None: ) -> FusedMoEPrepareAndFinalizeModular | None:
from .all2all_utils import maybe_make_prepare_finalize from .all2all_utils import maybe_make_prepare_finalize
return maybe_make_prepare_finalize( pf = maybe_make_prepare_finalize(
self.moe, self.moe_quant_config, routing_tables self.moe, self.moe_quant_config, routing_tables
) )
assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular)
return pf
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEExpertsModular:
# based on the all2all implementation, select the appropriate # based on the all2all implementation, select the appropriate
# gemm implementation # gemm implementation
raise NotImplementedError( raise ValueError(
f"{self.__class__.__name__} must select appropriate gemm " f"{self.__class__.__name__} uses the new modular kernel initialization "
"implementation based on the prepare_finalize" "logic. This function should not be called."
)
def prepare_dp_allgather_tensor(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
raise NotImplementedError(
"Method 'prepare_dp_allgather_tensor' is not implemented in "
f"{self.__class__.__name__}."
) )
@abstractmethod @abstractmethod
...@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property @property
def topk_indices_dtype(self) -> torch.dtype | None: def topk_indices_dtype(self) -> torch.dtype | None:
if self.moe_mk is not None: if self.moe_kernel is not None:
return self.moe_mk.prepare_finalize.topk_indices_dtype() return self.moe_kernel.prepare_finalize.topk_indices_dtype()
return None return None
@property @property
...@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property @property
def is_monolithic(self) -> bool: def is_monolithic(self) -> bool:
return False if self.moe_kernel is None:
if hasattr(self, "experts_cls"):
return self.experts_cls.is_monolithic()
else:
return False
return self.moe_kernel.is_monolithic
def apply( def apply(
self, self,
......
...@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( ...@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel, FusedMoEKernel,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalizeModular,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
# --8<-- [end:modular_fused_moe] # --8<-- [end:modular_fused_moe]
def __init__( def __init__(
self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel
): ):
super().__init__(old_quant_method.moe) super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config self.moe_quant_config = old_quant_method.moe_quant_config
self.moe_mk = experts self.moe_kernel = moe_kernel
self.disable_expert_map = getattr( self.disable_expert_map = getattr(
old_quant_method, old_quant_method,
"disable_expert_map", "disable_expert_map",
not self.moe_mk.supports_expert_map(), not self.moe_kernel.supports_expert_map(),
) )
self.old_quant_method = old_quant_method self.old_quant_method = old_quant_method
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
...@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def make( def make(
moe_layer: torch.nn.Module, moe_layer: torch.nn.Module,
old_quant_method: FusedMoEMethodBase, old_quant_method: FusedMoEMethodBase,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalizeModular,
shared_experts: torch.nn.Module | None, shared_experts: torch.nn.Module | None,
inplace: bool = False, inplace: bool = False,
) -> "FusedMoEModularMethod": ) -> "FusedMoEModularMethod":
return FusedMoEModularMethod( return FusedMoEModularMethod(
old_quant_method, old_quant_method,
FusedMoEModularKernel( FusedMoEKernel(
prepare_finalize, prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts, shared_experts,
...@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None assert self.moe_kernel is not None
return self.moe_mk( return self.moe_kernel.apply(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
......
...@@ -511,7 +511,7 @@ def make_routing_data( ...@@ -511,7 +511,7 @@ def make_routing_data(
return routing_data, gather_indx, scatter_indx return routing_data, gather_indx, scatter_indx
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -12,7 +12,7 @@ from vllm.platforms import current_platform ...@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
""" """
Prepare/Finalize using MoRI kernels. Prepare/Finalize using MoRI kernels.
""" """
......
...@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config, fp8_w8a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_fp8,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
make_fp8_moe_alpha_scales_for_fi,
prepare_fp8_moe_layer_for_fi, prepare_fp8_moe_layer_for_fi,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -103,9 +99,13 @@ def _get_priority_backends( ...@@ -103,9 +99,13 @@ def _get_priority_backends(
def backend_to_kernel_cls( def backend_to_kernel_cls(
backend: Fp8MoeBackend, backend: Fp8MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]: ) -> type[mk.FusedMoEExperts]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
TrtLlmFp8Experts,
)
return TrtLlmFp8Experts
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
...@@ -205,13 +205,11 @@ def select_fp8_moe_backend( ...@@ -205,13 +205,11 @@ def select_fp8_moe_backend(
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
allow_vllm_cutlass: bool = False, allow_vllm_cutlass: bool = False,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]:
""" """
Select the primary FP8 MoE backend Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime. Note: Shape-specific fallbacks may still occur at runtime.
""" """
k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
if config.is_lora_enabled: if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON) return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
...@@ -252,7 +250,7 @@ def select_fp8_moe_backend( ...@@ -252,7 +250,7 @@ def select_fp8_moe_backend(
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat, activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend) k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format k_cls, config, weight_key, activation_key, activation_format
...@@ -287,16 +285,6 @@ def select_fp8_moe_backend( ...@@ -287,16 +285,6 @@ def select_fp8_moe_backend(
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration." "vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
) )
# Handle FLASHINFER_TRTLLM specially (no kernel class).
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise( return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format requested_backend, config, weight_key, activation_key, activation_format
) )
...@@ -311,51 +299,32 @@ def select_fp8_moe_backend( ...@@ -311,51 +299,32 @@ def select_fp8_moe_backend(
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it. # If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend() fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.CUTLASS:
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
elif fi_backend == FlashinferMoeBackend.CUTLASS:
backend = Fp8MoeBackend.FLASHINFER_CUTLASS backend = Fp8MoeBackend.FLASHINFER_CUTLASS
return _return_or_raise( elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend, config, weight_key, activation_key, activation_format backend = Fp8MoeBackend.FLASHINFER_TRTLLM
)
else: else:
assert fi_backend == FlashinferMoeBackend.CUTEDSL raise ValueError(
raise ValueError("FlashInfer MaskedGEMM not supported for FP8") f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
)
k_cls = backend_to_kernel_cls(backend)
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else: else:
# If the user is not explicit about the backend, try both. # If the user is not explicit about the backend, try both.
for backend in [ for backend in [
Fp8MoeBackend.FLASHINFER_TRTLLM, Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS, Fp8MoeBackend.FLASHINFER_CUTLASS,
]: ]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: k_cls = backend_to_kernel_cls(backend)
k_cls = None supported, reason = k_cls.is_supported_config(
supported, reason = is_supported_config_trtllm_fp8( k_cls,
config, config,
weight_key, weight_key,
activation_key, activation_key,
activation_format, activation_format,
) )
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
...@@ -408,23 +377,14 @@ def select_fp8_moe_backend( ...@@ -408,23 +377,14 @@ def select_fp8_moe_backend(
# Select kernels in order of backend. # Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS: for backend in AVAILABLE_BACKENDS:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: k_cls = backend_to_kernel_cls(backend)
k_cls = None supported, reason = k_cls.is_supported_config(
supported, reason = is_supported_config_trtllm_fp8( k_cls,
config, config,
weight_key, weight_key,
activation_key, activation_key,
activation_format, activation_format,
) )
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
...@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config( ...@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config(
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
per_out_ch_quant: bool = False, per_out_ch_quant: bool = False,
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig:
""" """
Create FusedMoEQuantConfig for the specified FP8 Backend. Create FusedMoEQuantConfig for the specified FP8 Backend.
The FusedMoEQuantConfig holds the scales that are used The FusedMoEQuantConfig holds the scales that are used
...@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config( ...@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config(
In a future PR, we will have this function should be In a future PR, we will have this function should be
a method of the modular kernel itself. a method of the modular kernel itself.
""" """
# TRTLLM does not use Modular Kernel abstraction yet.
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN is mixed precision W8A16 config. # MARLIN is mixed precision W8A16 config.
if fp8_backend == Fp8MoeBackend.MARLIN: if fp8_backend == Fp8MoeBackend.MARLIN:
...@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config( ...@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config(
# (alpha = w_scale * a_scale) and inverse a2 scale. # (alpha = w_scale * a_scale) and inverse a2 scale.
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None: if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
assert a1_scale is not None and a2_scale is not None assert a1_scale is not None and a2_scale is not None
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config( return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
...@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config( ...@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config(
a2_scale=a2_scale, a2_scale=a2_scale,
a1_gscale=(1.0 / a1_scale), a1_gscale=(1.0 / a1_scale),
a2_gscale=(1.0 / a2_scale), a2_gscale=(1.0 / a2_scale),
g1_alphas=g1_alphas, g1_alphas=(w1_scale * a1_scale).squeeze(),
g2_alphas=g2_alphas, g2_alphas=(w2_scale * a2_scale).squeeze(),
) )
# All other backends use normal config. # All other backends use normal config.
return fp8_w8a8_moe_quant_config( return fp8_w8a8_moe_quant_config(
...@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config( ...@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config(
def make_fp8_moe_kernel( def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig, moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEExperts],
fp8_backend: Fp8MoeBackend, fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEKernel:
# Create Prepare/Finalize. # Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize( prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config, moe=moe_config,
quant_config=moe_quant_config, quant_config=moe_quant_config,
routing_tables=routing_tables, routing_tables=routing_tables,
allow_new_interface=True, allow_new_interface=True,
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
) )
assert prepare_finalize is not None assert prepare_finalize is not None
...@@ -605,7 +557,7 @@ def make_fp8_moe_kernel( ...@@ -605,7 +557,7 @@ def make_fp8_moe_kernel(
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explicit in # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=( shared_experts=(
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
...@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: ...@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
def backend_to_kernel_cls(
    backend: NvFp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """Return the candidate experts classes for an NvFP4 MoE backend.

    The returned list is ordered by preference (most preferred first).
    Each experts implementation is imported lazily, only when its backend
    is the one requested.

    Args:
        backend: The NvFP4 MoE backend to look up.

    Returns:
        A list with one or more experts classes for ``backend``.

    Raises:
        ValueError: If ``backend`` is not one of the backends handled here.
    """
    if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
        from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
            TrtLlmNvFp4ExpertsModular,
            TrtLlmNvFp4ExpertsMonolithic,
        )

        # NOTE: Monolithic is preferred over Modular, so it is listed first.
        return [TrtLlmNvFp4ExpertsMonolithic, TrtLlmNvFp4ExpertsModular]

    if backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        return [FlashInferExperts]

    if backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
            FlashInferCuteDSLExperts,
        )

        return [FlashInferCuteDSLExperts]

    if backend == NvFp4MoeBackend.VLLM_CUTLASS:
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            CutlassExpertsFp4,
        )

        return [CutlassExpertsFp4]

    if backend == NvFp4MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        return [MarlinExperts]

    raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
...@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend( ...@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend(
config: FusedMoEConfig, config: FusedMoEConfig,
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
""" """
Select the primary NvFP4 MoE backend Select the primary NvFP4 MoE backend
Note: Shape-specific fallbacks may still occur at runtime. Note: Shape-specific fallbacks may still occur at runtime.
...@@ -175,29 +181,21 @@ def select_nvfp4_moe_backend( ...@@ -175,29 +181,21 @@ def select_nvfp4_moe_backend(
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat, activation_format: mk.FusedMoEActivationFormat,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend) for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format k_cls, config, weight_key, activation_key, activation_format
) )
if supported: if supported:
logger.info_once(_make_log_backend(backend)) logger.info_once(_make_log_backend(backend))
return backend, k_cls return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason)) raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user. # Handle explicit moe_backend from user.
runner_backend = config.moe_backend runner_backend = config.moe_backend
if runner_backend != "auto": if runner_backend != "auto":
requested_backend = map_nvfp4_backend(runner_backend) requested_backend = map_nvfp4_backend(runner_backend)
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise( return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format requested_backend, config, weight_key, activation_key, activation_format
) )
...@@ -210,36 +208,14 @@ def select_nvfp4_moe_backend( ...@@ -210,36 +208,14 @@ def select_nvfp4_moe_backend(
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it. # If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend() backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
return _return_or_raise(
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: backend, config, weight_key, activation_key, activation_format
backend = NvFp4MoeBackend.FLASHINFER_TRTLLM )
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
else:
backend = fi_2_vllm_backend_map[fi_backend]
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else: else:
# If the user is not explicit about the backend, try each. # If the user is not explicit about the backend, try each.
for backend in FLASHINFER_NVFP4_MOE_BACKENDS: for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: for k_cls in backend_to_kernel_cls(backend):
k_cls = None
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, k_cls,
config, config,
...@@ -247,13 +223,13 @@ def select_nvfp4_moe_backend( ...@@ -247,13 +223,13 @@ def select_nvfp4_moe_backend(
activation_key, activation_key,
activation_format, activation_format,
) )
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
return backend, None return backend, k_cls
else: else:
logger.debug_once( logger.debug_once(
_make_log_unsupported(backend, reason), scope="local" _make_log_unsupported(backend, reason), scope="local"
) )
raise NotImplementedError( raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
...@@ -268,16 +244,7 @@ def select_nvfp4_moe_backend( ...@@ -268,16 +244,7 @@ def select_nvfp4_moe_backend(
# Select kernels in order of backend. # Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS: for backend in AVAILABLE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: for k_cls in backend_to_kernel_cls(backend):
k_cls = None # type: ignore[assignment]
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, k_cls,
config, config,
...@@ -286,11 +253,11 @@ def select_nvfp4_moe_backend( ...@@ -286,11 +253,11 @@ def select_nvfp4_moe_backend(
activation_format, activation_format,
) )
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls return backend, k_cls
else: else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local") logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError( raise NotImplementedError(
"No NvFp4 MoE backend supports the deployment configuration." "No NvFp4 MoE backend supports the deployment configuration."
...@@ -398,12 +365,8 @@ def make_nvfp4_moe_quant_config( ...@@ -398,12 +365,8 @@ def make_nvfp4_moe_quant_config(
w2_scale_2: torch.Tensor, w2_scale_2: torch.Tensor,
a13_scale: torch.Tensor, a13_scale: torch.Tensor,
a2_scale: torch.Tensor, a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig:
UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM] if backend == NvFp4MoeBackend.MARLIN:
if backend in UNSUPPORTED:
return None
elif backend == NvFp4MoeBackend.MARLIN:
return nvfp4_w4a16_moe_quant_config( return nvfp4_w4a16_moe_quant_config(
g1_alphas=w13_scale_2, g1_alphas=w13_scale_2,
g2_alphas=w2_scale_2, g2_alphas=w2_scale_2,
...@@ -420,22 +383,27 @@ def make_nvfp4_moe_quant_config( ...@@ -420,22 +383,27 @@ def make_nvfp4_moe_quant_config(
a2_gscale=(1.0 / a2_scale), a2_gscale=(1.0 / a2_scale),
w1_scale=w13_scale, w1_scale=w13_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
# NOTE(rob): this is a hack until the MoE kernels
# create their own quant configs. TRTLLM kernel
# does not accept swizzled input quant scales.
is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
) )
def make_nvfp4_moe_kernel( def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig, moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEExperts],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEKernel:
# Create Prepare/Finalize. # Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize( prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config, moe=moe_config,
quant_config=moe_quant_config, quant_config=moe_quant_config,
routing_tables=routing_tables, routing_tables=routing_tables,
allow_new_interface=True, allow_new_interface=True,
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
) )
assert prepare_finalize is not None assert prepare_finalize is not None
...@@ -460,7 +428,7 @@ def make_nvfp4_moe_kernel( ...@@ -460,7 +428,7 @@ def make_nvfp4_moe_kernel(
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explicit in # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=( shared_experts=(
......
...@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( ...@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_bf16, is_supported_config_trtllm_bf16,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoDPEPModular,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31, swap_w13_to_w31,
...@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel( ...@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend, backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
) -> mk.FusedMoEModularKernel | None: ) -> mk.FusedMoEKernel | None:
if backend in UNSUPPORTED_BACKEND: if backend in UNSUPPORTED_BACKEND:
return None return None
...@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel( ...@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel(
FlashInferExperts, FlashInferExperts,
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
FlashInferExperts( FlashInferExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel( ...@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel(
AiterExperts, AiterExperts,
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
AiterExperts( AiterExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -243,8 +243,8 @@ def make_unquantized_moe_kernel( ...@@ -243,8 +243,8 @@ def make_unquantized_moe_kernel(
elif backend == UnquantizedMoeBackend.TRITON: elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe import TritonExperts
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
TritonExperts( TritonExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -254,8 +254,8 @@ def make_unquantized_moe_kernel( ...@@ -254,8 +254,8 @@ def make_unquantized_moe_kernel(
elif backend == UnquantizedMoeBackend.XPU: elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe import XPUExperts from vllm.model_executor.layers.fused_moe import XPUExperts
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
XPUExperts( XPUExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import (
MoEPrepareAndFinalizeNaiveDPEPModular,
MoEPrepareAndFinalizeNaiveDPEPMonolithic,
make_moe_prepare_and_finalize_naive_dp_ep,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
MoEPrepareAndFinalizeNoDPEPModular,
MoEPrepareAndFinalizeNoDPEPMonolithic,
make_moe_prepare_and_finalize_no_dp_ep,
)
# Names re-exported by this module: the naive DP/EP and no-DP/EP
# prepare/finalize classes (monolithic and modular variants) together
# with their `make_*` helpers imported above.
__all__ = [
"MoEPrepareAndFinalizeNaiveDPEPMonolithic",
"MoEPrepareAndFinalizeNaiveDPEPModular",
"make_moe_prepare_and_finalize_naive_dp_ep",
"MoEPrepareAndFinalizeNoDPEPMonolithic",
"MoEPrepareAndFinalizeNoDPEPModular",
"make_moe_prepare_and_finalize_no_dp_ep",
]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
...@@ -14,7 +13,68 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input ...@@ -14,7 +13,68 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def _quantize_and_setup_dispatch(
    a1: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    defer_input_quant: bool = False,
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
    """Quantize the activations (unless deferred) and pick the scales to dispatch.

    Args:
        a1: Input activations.
        quant_config: Quantization config; supplies the input scale
            (``a1_gscale`` when ``use_nvfp4_w4a4`` is set, otherwise
            ``a1_scale``) and the quantization parameters.
        defer_input_quant: When True, ``a1`` is passed through untouched and
            quantization is left to the MoE kernel.

    Returns:
        A tuple ``(a1q, scales)``. ``scales`` is ``None`` when there is no
        scale to gather (quantization deferred, or the scale is 0-dim);
        otherwise it is a single-element list holding the activation scale.
    """
    if defer_input_quant:
        # Defer input quantization to the MoE kernel.
        a1q = a1
        a1q_scale = None
    else:
        input_sf = (
            quant_config.a1_gscale
            if quant_config.use_nvfp4_w4a4
            else quant_config.a1_scale
        )
        # NOTE: swizzling pads the scales to multiple of 128
        # which makes the scales tensor different shape than
        # the hidden states, breaking the A2A kernel. So, we
        # delay the swizzling until after the A2A.
        a1q, a1q_scale = moe_kernel_quantize_input(
            a1,
            input_sf,
            quant_dtype=quant_config.quant_dtype,
            per_act_token_quant=quant_config.per_act_token_quant,
            block_shape=quant_config.block_shape,
            is_fp4_scale_swizzled=False,
        )

    # Skip gathering scales if we have static quantization
    # (the scale is a scalar, replicated on all ranks) or
    # if quantization is deferred.
    skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
    scales = None if skip_gather_scales else [a1q_scale]

    return a1q, scales
def _unwrap_scale_and_prepare_for_moe(
    scales: list[torch.Tensor] | None,
    quant_config: FusedMoEQuantConfig,
) -> torch.Tensor:
    """Extract the single dispatched scale and swizzle it if required.

    Args:
        scales: The extra tensors returned by the dispatch; must be a
            list containing exactly one scale tensor.
        quant_config: Quantization config; the scale is interleaved only
            when ``quant_dtype`` is ``"nvfp4"`` and
            ``is_nvfp4_scale_swizzled`` is set.

    Returns:
        The activation scale, interleaved via
        ``nvfp4_block_scale_interleave`` when the condition above holds,
        otherwise returned as received.
    """
    assert scales is not None and len(scales) == 1
    scale = scales[0]

    needs_swizzle = (
        quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled
    )
    if not needs_swizzle:
        return scale

    # Apply swizzling after a2a since the MoE kernel needs it.
    assert scale is not None
    if scale.element_size() == 1:
        # One-byte scales are reinterpreted as uint8 before interleaving.
        scale = scale.view(torch.uint8)
    return nvfp4_block_scale_interleave(scale)
class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
"""
Naive Prepare/Finalize for Dp/Ep case for Modular Kernels.
Uses Torch AR/RS or AR for dispatch/combine operations, applied
to the topk weights and ids.
"""
def __init__( def __init__(
self, self,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
...@@ -51,6 +111,8 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): ...@@ -51,6 +111,8 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False, defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
"""Quantize and Dispatch Topk Weights and Topk Ids."""
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
assert topk == 1, ( assert topk == 1, (
...@@ -59,30 +121,7 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): ...@@ -59,30 +121,7 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
# Note: do not use inplace for shared experts overlap # Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quantization to the MoE kernel. a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
use_nvfp4 = quant_config.use_nvfp4_w4a4
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
res = get_ep_group().dispatch( res = get_ep_group().dispatch(
a1q, a1q,
...@@ -91,17 +130,13 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): ...@@ -91,17 +130,13 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales, extra_tensors=scales,
) )
if skip_gather_scales:
if scales is None:
a1q, topk_weights, topk_ids = res a1q, topk_weights, topk_ids = res
a1q_scale = None
else: else:
a1q, topk_weights, topk_ids, scales = res a1q, topk_weights, topk_ids, scales = res
assert scales is not None and len(scales) == 1 a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
a1q_scale = scales[0]
if quant_config.quant_dtype == "nvfp4":
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights
...@@ -130,8 +165,22 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): ...@@ -130,8 +165,22 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
) )
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
"""MoE prepare and finalize without expert parallelism.""" """
Naive Prepare/Finalize for Dp/Ep case for Monolithic Kernels.
Uses Torch AR/RS or AR for dispatch/combine operations, applied
to the router logits (the MoE kernel runs the router internally).
"""
def __init__(
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__()
self.is_sequence_parallel = is_sequence_parallel
self._num_dispatchers = num_dispatchers
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
...@@ -144,7 +193,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -144,7 +193,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
return None return None
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return 1 return self._num_dispatchers
def output_is_reduced(self) -> bool: def output_is_reduced(self) -> bool:
return False return False
...@@ -152,58 +201,53 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -152,58 +201,53 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
topk_weights: torch.Tensor, router_logits: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False, defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareMonolithicResultType:
if apply_router_weight_on_input: """Quantize and Dispatch Router Logits."""
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quant to moe kernel for backends (e.g. AITER, FI) a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
# which use a single kernel call for quant + experts.
if defer_input_quant:
return a1, None, None, None, None
input_sf = ( res = get_ep_group().dispatch_router_logits(
quant_config.a1_gscale a1q,
if quant_config.use_nvfp4_w4a4 router_logits,
else quant_config.a1_scale is_sequence_parallel=self.is_sequence_parallel,
) extra_tensors=scales,
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
input_sf,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
) )
return a1q, a1q_scale, None, None, None if scales is None:
a1q, router_logits = res
a1q_scale = None
else:
a1q, router_logits, scales = res
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
return a1q, a1q_scale, router_logits
def finalize( def finalize(
self, self,
output: torch.Tensor,
fused_expert_output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, ) -> torch.Tensor:
topk_ids: torch.Tensor, out = get_ep_group().combine(
apply_router_weight_on_input: bool, fused_expert_output, is_sequence_parallel=self.is_sequence_parallel
weight_and_reduce_impl: mk.TopKWeightAndReduce, )
) -> None: return out
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
def make_moe_prepare_and_finalize_naive_dp_ep(
    use_monolithic: bool,
    is_sequence_parallel: bool = False,
    num_dispatchers: int = 1,
) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic:
    """Build the naive DP/EP prepare/finalize object.

    Args:
        use_monolithic: Selects ``MoEPrepareAndFinalizeNaiveDPEPMonolithic``
            when True, ``MoEPrepareAndFinalizeNaiveDPEPModular`` otherwise.
        is_sequence_parallel: Forwarded to the selected class.
        num_dispatchers: Forwarded to the selected class.

    Returns:
        An instance of the selected prepare/finalize class.
    """
    if use_monolithic:
        impl_cls = MoEPrepareAndFinalizeNaiveDPEPMonolithic
    else:
        impl_cls = MoEPrepareAndFinalizeNaiveDPEPModular

    return impl_cls(
        is_sequence_parallel=is_sequence_parallel,
        num_dispatchers=num_dispatchers,
    )
...@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts( ...@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts(
) )
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): class AiterExperts(mk.FusedMoEExpertsModular):
@property @property
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(self) -> bool:
return True return True
......
...@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike(): ...@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike():
# TODO(bowen): When using `FusedMoEModularKernel`, this # TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since # can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert # `FusedMoEPrepareAndFinalizeModular` will return the expert
# token count, in some cases directly from the kernel. # token count, in some cases directly from the kernel.
# However, now there are many code paths not using # However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`, # the modular kernel, e.g. calling `fused_experts`,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment