Unverified Commit 97995f63 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Signed-off-by: default avatarRobert Shaw <robertgshaw2@gmail.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
...@@ -18,7 +18,7 @@ def get_local_sizes(): ...@@ -18,7 +18,7 @@ def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""Base class for FlashInfer MoE prepare and finalize operations.""" """Base class for FlashInfer MoE prepare and finalize operations."""
def __init__( def __init__(
...@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch( ...@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch(
ep_size, ep_size,
) )
# Swizzle after the A2A if nvfp4. # Swizzle after the A2A if MoE kernel expects swizzled scales.
if quant_config.quant_dtype == "nvfp4": if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
if x_sf.element_size() == 1: if x_sf.element_size() == 1:
x_sf = x_sf.view(torch.uint8) x_sf = x_sf.view(torch.uint8)
x_sf = nvfp4_block_scale_interleave(x_sf) x_sf = nvfp4_block_scale_interleave(x_sf)
......
...@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import ( ...@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import (
logger = init_logger(__name__) logger = init_logger(__name__)
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
......
...@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe( ...@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe(
return True return True
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): class FlashInferExperts(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: mk.FusedMoEConfig, moe_config: mk.FusedMoEConfig,
......
...@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
...@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool: ...@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool:
return True return True
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor and Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
def _supports_routing_method(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
else:
raise ValueError("Unsupported quantization scheme.")
def _supports_routing_method_bf16( def _supports_routing_method_bf16(
routing_method: RoutingMethodType, routing_method: RoutingMethodType,
) -> bool: ) -> bool:
...@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo ...@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
return not moe_parallel_config.enable_eplb return not moe_parallel_config.enable_eplb
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def is_supported_config_trtllm_fp8(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason(f"current device {current_platform.device_name}")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
elif not _supports_routing_method(
weight_key, activation_key, moe_config.routing_method
):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason(f"activation format {activation_format}")
elif not _supports_router_logits_dtype(
moe_config.router_logits_dtype, moe_config.routing_method
):
return False, _make_reason(
"float32 router_logits with non-DeepSeekV3 routing "
f"{moe_config.router_logits_dtype}x{moe_config.routing_method}"
)
return True, None
def is_supported_config_trtllm_bf16( def is_supported_config_trtllm_bf16(
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
activation_format: mk.FusedMoEActivationFormat, activation_format: mk.FusedMoEActivationFormat,
...@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16( ...@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16(
return True, None return True, None
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float | None = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
assert top_k <= global_num_experts
assert top_k <= 10
assert global_num_experts % 4 == 0
assert block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512
# The DeepSeekV3 routing method requires float32 router logits.
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_logits = routing_logits.to(torch.float32)
if routing_bias is not None:
routing_bias = routing_bias.to(x.dtype)
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
return flashinfer_trtllm_fp8_block_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=w13_weight,
gemm1_weights_scale=w13_weight_scale_inv,
gemm2_weights=w2_weight,
gemm2_weights_scale=w2_weight_scale_inv,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)
def flashinfer_fused_moe_blockscale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(x)
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
op_name="flashinfer_fused_moe_blockscale_fp8",
op_func=flashinfer_fused_moe_blockscale_fp8,
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
def fi_trtllm_fp8_per_tensor_moe(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
quant_hidden_states, _ = moe_kernel_quantize_input(
hidden_states,
input_scale,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
)
from flashinfer.fused_moe.core import ActivationType
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
# The DeepSeekV3 routing method requires float32 router logits.
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_logits = routing_logits.to(torch.float32)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=quant_hidden_states,
gemm1_weights=gemm1_weights,
output1_scales_scalar=output1_scales_scalar,
output1_scales_gate_scalar=output1_scales_gate_scalar,
gemm2_weights=gemm2_weights,
output2_scales_scalar=output2_scales_scalar,
num_experts=num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type,
# TODO: enum type Required for flashinfer==0.6.3, remove with update
# https://github.com/flashinfer-ai/flashinfer/pull/2508
activation_type=ActivationType(activation_type),
)
def fi_trtllm_fp8_per_tensor_moe_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
op_name="fi_trtllm_fp8_per_tensor_moe",
op_func=fi_trtllm_fp8_per_tensor_moe,
mutates_args=["hidden_states"],
fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
def flashinfer_fused_moe_bf16( def flashinfer_fused_moe_bf16(
routing_logits: torch.Tensor, routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None, routing_bias: torch.Tensor | None,
......
...@@ -489,7 +489,7 @@ def invoke_moe_batched_triton_kernel( ...@@ -489,7 +489,7 @@ def invoke_moe_batched_triton_kernel(
) )
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
""" """
A reference prepare/finalize class that reorganizes the tokens into A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format expert batched format, i.e. E x max_num_tokens x K. This is the format
...@@ -645,7 +645,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -645,7 +645,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
""" """
A reference MoE expert class that operates on expert batched format, A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the batched i.e. E x max_num_tokens x K. This is the format that the batched
...@@ -877,7 +877,7 @@ def batched_moe_kernel_quantize_input( ...@@ -877,7 +877,7 @@ def batched_moe_kernel_quantize_input(
return A_q, A_q_scale return A_q, A_q_scale
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedTritonExperts(mk.FusedMoEExpertsModular):
""" """
A Triton based MoE expert class that operates on expert batched format, A Triton based MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the batched i.e. E x max_num_tokens x K. This is the format that the batched
......
...@@ -526,7 +526,7 @@ def batched_fused_marlin_moe( ...@@ -526,7 +526,7 @@ def batched_fused_marlin_moe(
return output return output
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): class MarlinExpertsBase(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
......
...@@ -1736,7 +1736,7 @@ def fused_experts_impl( ...@@ -1736,7 +1736,7 @@ def fused_experts_impl(
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1 # This needs separate memory since it's used concurrently with cache1
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation( activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation(
N, activation_enum N, activation_enum
) )
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
...@@ -1924,7 +1924,7 @@ def fused_experts_impl( ...@@ -1924,7 +1924,7 @@ def fused_experts_impl(
return out_hidden_states return out_hidden_states
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonExperts(mk.FusedMoEExpertsModular):
"""Triton-based fused MoE expert implementation.""" """Triton-based fused MoE expert implementation."""
def __init__( def __init__(
......
...@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute, FusedMoEExpertsModular,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalizeModular,
) )
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
...@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super().__init__() super().__init__()
self.moe: FusedMoEConfig = moe self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None self.moe_quant_config: FusedMoEQuantConfig | None = None
self.moe_mk: mk.FusedMoEModularKernel | None = None self.moe_kernel: mk.FusedMoEKernel | None = None
@property @property
def supports_internal_mk(self) -> bool: def supports_internal_mk(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for # NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface. # completed migration to the new internal MK interface.
return self.moe_mk is not None return self.moe_kernel is not None
@property @property
def mk_owns_shared_expert(self) -> bool: def mk_owns_shared_expert(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for # NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface. # completed migration to the new internal MK interface.
return self.moe_mk is not None and self.moe_mk.shared_experts is not None return (
self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
)
@abstractmethod @abstractmethod
def create_weights( def create_weights(
...@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None: ) -> FusedMoEPrepareAndFinalizeModular | None:
from .all2all_utils import maybe_make_prepare_finalize from .all2all_utils import maybe_make_prepare_finalize
return maybe_make_prepare_finalize( pf = maybe_make_prepare_finalize(
self.moe, self.moe_quant_config, routing_tables self.moe, self.moe_quant_config, routing_tables
) )
assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular)
return pf
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEExpertsModular:
# based on the all2all implementation, select the appropriate # based on the all2all implementation, select the appropriate
# gemm implementation # gemm implementation
raise NotImplementedError( raise ValueError(
f"{self.__class__.__name__} must select appropriate gemm " f"{self.__class__.__name__} uses the new modular kernel initialization "
"implementation based on the prepare_finalize" "logic. This function should not be called."
)
def prepare_dp_allgather_tensor(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
raise NotImplementedError(
"Method 'prepare_dp_allgather_tensor' is not implemented in "
f"{self.__class__.__name__}."
) )
@abstractmethod @abstractmethod
...@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property @property
def topk_indices_dtype(self) -> torch.dtype | None: def topk_indices_dtype(self) -> torch.dtype | None:
if self.moe_mk is not None: if self.moe_kernel is not None:
return self.moe_mk.prepare_finalize.topk_indices_dtype() return self.moe_kernel.prepare_finalize.topk_indices_dtype()
return None return None
@property @property
...@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property @property
def is_monolithic(self) -> bool: def is_monolithic(self) -> bool:
return False if self.moe_kernel is None:
if hasattr(self, "experts_cls"):
return self.experts_cls.is_monolithic()
else:
return False
return self.moe_kernel.is_monolithic
def apply( def apply(
self, self,
......
...@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( ...@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel, FusedMoEKernel,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalizeModular,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
# --8<-- [end:modular_fused_moe] # --8<-- [end:modular_fused_moe]
def __init__( def __init__(
self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel
): ):
super().__init__(old_quant_method.moe) super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config self.moe_quant_config = old_quant_method.moe_quant_config
self.moe_mk = experts self.moe_kernel = moe_kernel
self.disable_expert_map = getattr( self.disable_expert_map = getattr(
old_quant_method, old_quant_method,
"disable_expert_map", "disable_expert_map",
not self.moe_mk.supports_expert_map(), not self.moe_kernel.supports_expert_map(),
) )
self.old_quant_method = old_quant_method self.old_quant_method = old_quant_method
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
...@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def make( def make(
moe_layer: torch.nn.Module, moe_layer: torch.nn.Module,
old_quant_method: FusedMoEMethodBase, old_quant_method: FusedMoEMethodBase,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalizeModular,
shared_experts: torch.nn.Module | None, shared_experts: torch.nn.Module | None,
inplace: bool = False, inplace: bool = False,
) -> "FusedMoEModularMethod": ) -> "FusedMoEModularMethod":
return FusedMoEModularMethod( return FusedMoEModularMethod(
old_quant_method, old_quant_method,
FusedMoEModularKernel( FusedMoEKernel(
prepare_finalize, prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts, shared_experts,
...@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None assert self.moe_kernel is not None
return self.moe_mk( return self.moe_kernel.apply(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
......
...@@ -511,7 +511,7 @@ def make_routing_data( ...@@ -511,7 +511,7 @@ def make_routing_data(
return routing_data, gather_indx, scatter_indx return routing_data, gather_indx, scatter_indx
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
...@@ -56,25 +57,25 @@ logger = init_logger(__name__) ...@@ -56,25 +57,25 @@ logger = init_logger(__name__)
# MoE kernel implementations. # MoE kernel implementations.
# #
# The following main classes are defined: # The following main classes are defined:
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE # * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of Moe outputs. # inputs (e.g. quantization, distribution) and finalization of Moe outputs.
# The prepare method must take care of any needed quantization and the # The prepare method must take care of any needed quantization and the
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method, # finalize method, informed by the FusedMoEExpertsModular method,
# may apply weights and/or do the final reduction of the output. # may apply weights and/or do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused # * FusedMoEExpertsModular - an abstract base class for the main fused
# MoE operation, i.e matmul + act_mul + optionally quant + matmul. # MoE operation, i.e matmul + act_mul + optionally quant + matmul.
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do # Some FusedMoEExpertsModular implementations may choose to do
# the weight application and/or reduction. The class communicates this # the weight application and/or reduction. The class communicates this
# to [Finalize] via a TopKWeightAndReduce object. # to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a # * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to # FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to
# provide the standard fused MoE kernel interface. # provide the standard fused MoE kernel interface.
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen # * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed # by the FusedMoEExpertsModular implementation that is passed
# on to [Finalize]. # on to [Finalize].
# #
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single # [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective # class `FusedMoEPrepareAndFinalizeModular` since they could use collective
# communication mechanisms that need to be consistent. # communication mechanisms that need to be consistent.
# #
...@@ -155,25 +156,96 @@ PrepareResultType = tuple[ ...@@ -155,25 +156,96 @@ PrepareResultType = tuple[
torch.Tensor | None, torch.Tensor | None,
] ]
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - dispatched router logits.
#
# See `prepare_monolithic` method below.
#
PrepareMonolithicResultType = tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor,
]
ReceiverType = Callable[[], PrepareResultType] ReceiverType = Callable[[], PrepareResultType]
################################################################################
# Prepare/Finalize
################################################################################
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
An abstract base class for the [Quantize-Prepare] and [Finalize] steps An abstract base class for the [Quantize-Prepare] and [Finalize] steps
described above. described above.
There are two variants of this class:
* FusedMoEPrepareAndFinalizeModular - this operates on topk ids and weights
* FusedMoEPrepareAndFinalizeMonolithic - the operates on router_logits
""" """
def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"): def post_init_setup(self, fused_experts: "FusedMoEExperts"):
""" """
Initialize FusedMoEPrepareAndFinalize settings that depend on Initialize FusedMoEPrepareAndFinalizeModular settings that depend on
FusedMoEPermuteExpertsUnpermute experts object. FusedMoEExpertsModular experts object.
The FusedMoEPrepareAndFinalize implementations that have such The FusedMoEPrepareAndFinalizeModular implementations that have such
dependencies may choose to override this function. dependencies may choose to override this function.
""" """
return return
@property
@abstractmethod
def activation_format(self) -> FusedMoEActivationFormat:
"""
A property indicating the output format of the activations for the
'prepare' method.
"""
raise NotImplementedError
@abstractmethod
def topk_indices_dtype(self) -> torch.dtype | None:
"""
The PrepareFinalize All2All implementations generally constrain the
dtype of the topk_ids they support. This function returns the
required topk indices dtype so it can be respected.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@abstractmethod
def max_num_tokens_per_rank(self) -> int | None:
"""
Some PrepareFinalize All2All implementations are batched. Meaning,
they can process only as set of tokens at a time. This
function returns the batch size i.e the maximum number of tokens
the implementation can process at a time.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@abstractmethod
def num_dispatchers(self) -> int:
raise NotImplementedError
@abstractmethod
def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of finalize is reduced across all
ranks.
"""
raise NotImplementedError
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
described above for the Modular case.
"""
@abstractmethod @abstractmethod
def prepare( def prepare(
self, self,
...@@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC):
activations, before quantization + dispatching. activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts. - quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to - defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute defer input quantization to the FusedMoEExpertsModular
in cases where the compute kernel expects unquantized inputs in cases where the compute kernel expects unquantized inputs
Returns a tuple of: Returns a tuple of:
...@@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to - defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute defer input quantization to the FusedMoEExpertsModular
in cases where the compute kernel expects unquantized inputs in cases where the compute kernel expects unquantized inputs
Returns a callback or a hook callback pair that when invoked waits for Returns a callback or a hook callback pair that when invoked waits for
...@@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize):
def activation_format(self) -> FusedMoEActivationFormat: """
""" An abstract base class for the [Quantize-Prepare] and [Finalize] steps
A property indicating the output format of the activations for the described above for the monolithic case.
'prepare' method. """
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def topk_indices_dtype(self) -> torch.dtype | None: def prepare(
self,
a1: torch.Tensor,
router_logits: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> PrepareMonolithicResultType:
""" """
The PrepareFinalize All2All implementations generally constrain the Optional method for subclasses compatible with monolithic
dtype of the topk_ids they support. This function returns the FusedMoEExpertsModular kernels.
required topk indices dtype so it can be respected.
Return None if there are no such restrictions. Perform any quantization (and/or) dispatching needed for this kernel.
- a1: The (unquantized) input to the MoE layer.
- quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEExpertsModular
Returns a tuple of:
- quantized + dispatched a.
- Optional quantized + dispatched a1_scales.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def max_num_tokens_per_rank(self) -> int | None: def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor:
""" """
Some PrepareFinalize All2All implementations are batched. Meaning, Optional method for subclasses compatible with monolithic
they can process only as set of tokens at a time. This FusedMoEExpertsModular kernels.
function returns the batch size i.e the maximum number of tokens
the implementation can process at a time. Perform any combine plus apply weights and perform a reduction on the
Return None if there are no such restrictions. fused experts output.
- fused_expert_output: The unweighted, unreduced output of the fused
experts, it will have (M, topk, K) shape.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def num_dispatchers(self) -> int:
raise NotImplementedError
@abstractmethod ################################################################################
def output_is_reduced(self) -> bool: # Experts
""" ################################################################################
Indicates whether or not the output of finalize is reduced across all
ranks.
"""
raise NotImplementedError
# TODO: add supported activations method (return string) # TODO: add supported activations method (return string)
class FusedMoEPermuteExpertsUnpermute(ABC): class FusedMoEExperts(ABC):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
...@@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
@staticmethod
def is_monolithic() -> bool:
raise NotImplementedError("Implemented by subclasses.")
@property @property
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(self) -> bool:
""" """
...@@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
# #
# Various helpers for registering support for various features. # Various helpers for registering support for various features.
# Used by the oracle to select a particular kernel for a deployment. # Used by the oracle to select a particular kernel for a deployment.
...@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@staticmethod @staticmethod
def is_supported_config( def is_supported_config(
cls: type["FusedMoEPermuteExpertsUnpermute"], cls: type["FusedMoEExperts"],
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
...@@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return False, _make_reason( return False, _make_reason(
f"parallel config {moe_config.moe_parallel_config}" f"parallel config {moe_config.moe_parallel_config}"
) )
elif not cls._supports_routing_method(
moe_config.routing_method, weight_key, activation_key
):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif not cls._supports_router_logits_dtype(
moe_config.router_logits_dtype,
moe_config.routing_method,
):
return False, _make_reason(
f"router logits dtype {moe_config.router_logits_dtype}"
)
elif not cls._supports_shape(moe_config.hidden_dim):
return False, _make_reason(
f"{moe_config.hidden_dim} hidden dim is not supported"
)
elif activation_format != cls.activation_format(): elif activation_format != cls.activation_format():
return False, _make_reason(f"{activation_format.value} activation format") return False, _make_reason(f"{activation_format.value} activation format")
return True, None return True, None
...@@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@abstractmethod @abstractmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
""" """
Whether the kernel supports deployment in expert parallel. Whether the kernel supports deployment in particular parallel config.
Can be overriden if a kernel does not support EP, SP or some other
configuration.
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""
Whether the kernel supports a routing method (e.g. GroupedTopK).
Can be overriden by monolithic kernels that execute the router
in addition to the experts if certain routers are not supported.
"""
return True
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
Whether a kernel supports a particular dtype for router logits input.
Can be overriden by monolithic kernels that execute the router
in addition to the experts if certain dtypes are not supported.
"""
return True
@staticmethod
def _supports_shape(hidden_dim: int) -> bool:
"""
Whether a kernel supports a particular shape. Can be overridden if a kernel
has specific shape requirements.
"""
return True
# #
# Various helpers for accessing quantization parameters from the # Various helpers for accessing quantization parameters from the
# quant_config. # quant_config.
...@@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
""" """
return False return False
def enable_chunking(self):
return (
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
)
class FusedMoEExpertsModular(FusedMoEExperts):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
@staticmethod
def is_monolithic() -> bool:
return False
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
""" """
Workspace type: The dtype to use for the workspace tensors. Workspace type: The dtype to use for the workspace tensors.
...@@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
) -> None: ) -> None:
apply_moe_activation(activation, output, input) apply_moe_activation(activation, output, input)
def enable_chunking(self): @abstractmethod
return (
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
)
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
raise NotImplementedError raise NotImplementedError
...@@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise NotImplementedError raise NotImplementedError
class FusedMoEExpertsMonolithic(FusedMoEExperts):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above, but with the monolithic interface (accepts router logits
rather than topk ids and weights).
"""
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""
Whether the kernel supports a routing method (e.g. GroupedTopK).
Monolithic kernels should explicitly opt-in to support.
"""
raise NotImplementedError
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
Whether the kernel supports a dtype for router logits.
Modular kernels should opt-in to support.
"""
raise NotImplementedError
@staticmethod
def is_monolithic() -> bool:
return True
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
"""
Same as apply(), except uses router_logits as opposed
to the topk_ids and topk_weights. This is useful for kernels
with fused router and fused_experts (e.g. FLASHINFER_TRTLLM).
"""
raise NotImplementedError
def _slice_scales( def _slice_scales(
scales: torch.Tensor | None, start: int, end: int scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None: ) -> torch.Tensor | None:
...@@ -802,75 +1006,32 @@ def _slice_scales( ...@@ -802,75 +1006,32 @@ def _slice_scales(
return None return None
@final ################################################################################
class FusedMoEModularKernel(torch.nn.Module): # Kernel
""" ################################################################################
This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that
is compatible with the `fused_experts` function in fused_moe.py.
It takes care of managing any required scratch space.
Note: Instances of this class should only be used for a single model
layer due to any layer specific state that may be used by the component
objects.
"""
@final
class FusedMoEKernelModularImpl:
def __init__( def __init__(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalizeModular,
fused_experts: FusedMoEPermuteExpertsUnpermute, fused_experts: FusedMoEExpertsModular,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None, moe_parallel_config: FusedMoEParallelConfig | None = None,
inplace: bool = False, inplace: bool = False,
): ):
super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.shared_experts = shared_experts self.shared_experts = shared_experts
self.moe_parallel_config = moe_parallel_config
self.inplace = inplace self.inplace = inplace
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
# if not provided, assume this kernel is
# running in a non-DP+EP context
self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
self.is_dp_ep = ( self.is_dp_ep = (
moe_parallel_config is not None moe_parallel_config is not None
and moe_parallel_config.dp_size > 1 and moe_parallel_config.dp_size > 1
and moe_parallel_config.use_ep and moe_parallel_config.use_ep
) )
self._post_init_setup()
assert (
prepare_finalize.activation_format == fused_experts.activation_format()
), (
f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_format()}"
)
def _post_init_setup(self):
"""
Resolve any leftover setup dependencies between self.prepare_finalize
and self.fused_experts here.
"""
self.prepare_finalize.post_init_setup(self.fused_experts)
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.
"""
return self.fused_experts.supports_expert_map()
def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of fused MoE kernel
is reduced across all ranks.
"""
return self.prepare_finalize.output_is_reduced()
def _chunk_info(self, M: int) -> tuple[int, int]: def _chunk_info(self, M: int) -> tuple[int, int]:
""" """
Compute number of chunks and chunk size for given M. Compute number of chunks and chunk size for given M.
...@@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for # Force worst-case allocation in profiling run for
# "mk.FusedMoEModularKernel.Standard" formats where this is only bounded # "mk.FusedMoEKernel.Standard" formats where this is only bounded
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with # by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
# DP+EP due to the random token routing. # DP+EP due to the random token routing.
is_profile_run = ( is_profile_run = (
...@@ -1313,13 +1474,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1313,13 +1474,13 @@ class FusedMoEModularKernel(torch.nn.Module):
assert shared_output is not None assert shared_output is not None
return shared_output, output return shared_output, output
def forward( def apply(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
activation: MoEActivation = MoEActivation.SILU, activation: MoEActivation = MoEActivation.SILU,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
...@@ -1334,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1334,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module):
- hidden_states: (torch.Tensor): The input tensor to the MoE layer. - hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of - topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
the layer.
- topk_ids (torch.Tensor): A map of row to expert id. - topk_ids (torch.Tensor): A map of row to expert id.
- activation (MoEActivation): The activation function to apply after the first - activation (MoEActivation): The activation function to apply after the first
MoE layer. MoE layer.
...@@ -1354,7 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1354,7 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module):
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
if self.inplace: if self.inplace:
assert self.shared_experts is None assert self.shared_experts is None
assert not disable_inplace() assert not disable_inplace()
...@@ -1400,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1400,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input, apply_router_weight_on_input,
shared_experts_input=shared_experts_input, shared_experts_input=shared_experts_input,
) )
@final
class FusedMoEKernelMonolithicImpl:
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic,
fused_experts: FusedMoEExpertsMonolithic,
):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
"""
Same as forward(), except uses router_logits as opposed
to the topk_ids and topk_weights. This is used for kernels
that have fused router + experts (e.g. FLASHINFER_TRTLLM).
"""
# TODO(rob): add inplace support.
a1q, a1q_scale, router_logits = self.prepare_finalize.prepare(
hidden_states,
router_logits=router_logits,
quant_config=self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
fused_out = self.fused_experts.apply(
hidden_states=a1q,
w1=w1,
w2=w2,
router_logits=router_logits,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
a1q_scale=a1q_scale,
# grouped topk + fused topk bias parameters
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
topk_group=topk_group,
)
output = self.prepare_finalize.finalize(fused_out)
return output
@final
class FusedMoEKernel:
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEExperts,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
inplace: bool = False,
):
super().__init__()
self.shared_experts = shared_experts # NOTE: check if we can remove
# Initialize the implementation (monolithic or modular).
self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
if isinstance(
prepare_finalize, FusedMoEPrepareAndFinalizeModular
) and isinstance(fused_experts, FusedMoEExpertsModular):
self.impl = FusedMoEKernelModularImpl(
prepare_finalize,
fused_experts,
shared_experts,
moe_parallel_config,
inplace,
)
elif isinstance(
prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
assert shared_experts is None
assert not inplace
self.impl = FusedMoEKernelMonolithicImpl(
prepare_finalize,
fused_experts,
)
else:
raise ValueError(
"prepare_finalize and fused_experts must both be either monolithic "
f"or non-monolithic but got {prepare_finalize.__class__.__name__} "
f"and {fused_experts.__class__.__name__}"
)
self._post_init_setup()
@property
def is_monolithic(self) -> bool:
return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
@property
def prepare_finalize(self) -> FusedMoEPrepareAndFinalize:
return self.impl.prepare_finalize
@property
def fused_experts(self) -> FusedMoEExperts:
return self.impl.fused_experts
def _post_init_setup(self):
"""
Resolve any leftover setup dependencies between self.prepare_finalize
and self.fused_experts here.
"""
self.prepare_finalize.post_init_setup(self.impl.fused_experts)
assert (
self.prepare_finalize.activation_format
== self.fused_experts.activation_format()
)
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.
"""
return self.fused_experts.supports_expert_map()
def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of fused MoE kernel
is reduced across all ranks.
"""
return self.prepare_finalize.output_is_reduced()
def apply_monolithic(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
assert isinstance(self.impl, FusedMoEKernelMonolithicImpl)
return self.impl.apply(
hidden_states=hidden_states,
w1=w1,
w2=w2,
router_logits=router_logits,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
topk_group=topk_group,
)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor:
assert isinstance(self.impl, FusedMoEKernelModularImpl)
return self.impl.apply(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
...@@ -12,7 +12,7 @@ from vllm.platforms import current_platform ...@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
""" """
Prepare/Finalize using MoRI kernels. Prepare/Finalize using MoRI kernels.
""" """
......
...@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config, fp8_w8a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_fp8,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
make_fp8_moe_alpha_scales_for_fi,
prepare_fp8_moe_layer_for_fi, prepare_fp8_moe_layer_for_fi,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -103,9 +99,13 @@ def _get_priority_backends( ...@@ -103,9 +99,13 @@ def _get_priority_backends(
def backend_to_kernel_cls( def backend_to_kernel_cls(
backend: Fp8MoeBackend, backend: Fp8MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]: ) -> type[mk.FusedMoEExperts]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
TrtLlmFp8Experts,
)
return TrtLlmFp8Experts
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
...@@ -205,13 +205,11 @@ def select_fp8_moe_backend( ...@@ -205,13 +205,11 @@ def select_fp8_moe_backend(
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
allow_vllm_cutlass: bool = False, allow_vllm_cutlass: bool = False,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]:
""" """
Select the primary FP8 MoE backend Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime. Note: Shape-specific fallbacks may still occur at runtime.
""" """
k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
if config.is_lora_enabled: if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON) return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
...@@ -252,7 +250,7 @@ def select_fp8_moe_backend( ...@@ -252,7 +250,7 @@ def select_fp8_moe_backend(
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat, activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend) k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format k_cls, config, weight_key, activation_key, activation_format
...@@ -287,16 +285,6 @@ def select_fp8_moe_backend( ...@@ -287,16 +285,6 @@ def select_fp8_moe_backend(
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration." "vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
) )
# Handle FLASHINFER_TRTLLM specially (no kernel class).
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise( return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format requested_backend, config, weight_key, activation_key, activation_format
) )
...@@ -311,51 +299,32 @@ def select_fp8_moe_backend( ...@@ -311,51 +299,32 @@ def select_fp8_moe_backend(
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it. # If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend() fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.CUTLASS:
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
elif fi_backend == FlashinferMoeBackend.CUTLASS:
backend = Fp8MoeBackend.FLASHINFER_CUTLASS backend = Fp8MoeBackend.FLASHINFER_CUTLASS
return _return_or_raise( elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend, config, weight_key, activation_key, activation_format backend = Fp8MoeBackend.FLASHINFER_TRTLLM
)
else: else:
assert fi_backend == FlashinferMoeBackend.CUTEDSL raise ValueError(
raise ValueError("FlashInfer MaskedGEMM not supported for FP8") f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
)
k_cls = backend_to_kernel_cls(backend)
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else: else:
# If the user is not explicit about the backend, try both. # If the user is not explicit about the backend, try both.
for backend in [ for backend in [
Fp8MoeBackend.FLASHINFER_TRTLLM, Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS, Fp8MoeBackend.FLASHINFER_CUTLASS,
]: ]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: k_cls = backend_to_kernel_cls(backend)
k_cls = None supported, reason = k_cls.is_supported_config(
supported, reason = is_supported_config_trtllm_fp8( k_cls,
config, config,
weight_key, weight_key,
activation_key, activation_key,
activation_format, activation_format,
) )
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
...@@ -408,23 +377,14 @@ def select_fp8_moe_backend( ...@@ -408,23 +377,14 @@ def select_fp8_moe_backend(
# Select kernels in order of backend. # Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS: for backend in AVAILABLE_BACKENDS:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: k_cls = backend_to_kernel_cls(backend)
k_cls = None supported, reason = k_cls.is_supported_config(
supported, reason = is_supported_config_trtllm_fp8( k_cls,
config, config,
weight_key, weight_key,
activation_key, activation_key,
activation_format, activation_format,
) )
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
...@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config( ...@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config(
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
per_out_ch_quant: bool = False, per_out_ch_quant: bool = False,
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig:
""" """
Create FusedMoEQuantConfig for the specified FP8 Backend. Create FusedMoEQuantConfig for the specified FP8 Backend.
The FusedMoEQuantConfig holds the scales that are used The FusedMoEQuantConfig holds the scales that are used
...@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config( ...@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config(
In a future PR, we will have this function should be In a future PR, we will have this function should be
a method of the modular kernel itself. a method of the modular kernel itself.
""" """
# TRTLLM does not use Modular Kernel abstraction yet.
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN is mixed precision W8A16 config. # MARLIN is mixed precision W8A16 config.
if fp8_backend == Fp8MoeBackend.MARLIN: if fp8_backend == Fp8MoeBackend.MARLIN:
...@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config( ...@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config(
# (alpha = w_scale * a_scale) and inverse a2 scale. # (alpha = w_scale * a_scale) and inverse a2 scale.
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None: if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
assert a1_scale is not None and a2_scale is not None assert a1_scale is not None and a2_scale is not None
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config( return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
...@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config( ...@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config(
a2_scale=a2_scale, a2_scale=a2_scale,
a1_gscale=(1.0 / a1_scale), a1_gscale=(1.0 / a1_scale),
a2_gscale=(1.0 / a2_scale), a2_gscale=(1.0 / a2_scale),
g1_alphas=g1_alphas, g1_alphas=(w1_scale * a1_scale).squeeze(),
g2_alphas=g2_alphas, g2_alphas=(w2_scale * a2_scale).squeeze(),
) )
# All other backends use normal config. # All other backends use normal config.
return fp8_w8a8_moe_quant_config( return fp8_w8a8_moe_quant_config(
...@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config( ...@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config(
def make_fp8_moe_kernel( def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig, moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEExperts],
fp8_backend: Fp8MoeBackend, fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEKernel:
# Create Prepare/Finalize. # Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize( prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config, moe=moe_config,
quant_config=moe_quant_config, quant_config=moe_quant_config,
routing_tables=routing_tables, routing_tables=routing_tables,
allow_new_interface=True, allow_new_interface=True,
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
) )
assert prepare_finalize is not None assert prepare_finalize is not None
...@@ -605,7 +557,7 @@ def make_fp8_moe_kernel( ...@@ -605,7 +557,7 @@ def make_fp8_moe_kernel(
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explicit in # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=( shared_experts=(
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
...@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: ...@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
def backend_to_kernel_cls( def backend_to_kernel_cls(
backend: NvFp4MoeBackend, backend: NvFp4MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]: ) -> list[type[mk.FusedMoEExperts]]:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError( from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
"FLASHINFER_TRTLLM doesn't support Modular Kernel Interface" TrtLlmNvFp4ExpertsModular,
TrtLlmNvFp4ExpertsMonolithic,
) )
# NOTE: prefer Monolthic > Modular, so return Monolithic first.
return [
TrtLlmNvFp4ExpertsMonolithic,
TrtLlmNvFp4ExpertsModular,
]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, FlashInferExperts,
) )
return FlashInferExperts return [FlashInferExperts]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts, FlashInferCuteDSLExperts,
) )
return FlashInferCuteDSLExperts return [FlashInferCuteDSLExperts]
elif backend == NvFp4MoeBackend.VLLM_CUTLASS: elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4, CutlassExpertsFp4,
) )
return CutlassExpertsFp4 return [CutlassExpertsFp4]
elif backend == NvFp4MoeBackend.MARLIN: elif backend == NvFp4MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts, MarlinExperts,
) )
return MarlinExperts return [MarlinExperts]
else: else:
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
...@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend( ...@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend(
config: FusedMoEConfig, config: FusedMoEConfig,
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
""" """
Select the primary NvFP4 MoE backend Select the primary NvFP4 MoE backend
Note: Shape-specific fallbacks may still occur at runtime. Note: Shape-specific fallbacks may still occur at runtime.
...@@ -175,29 +181,21 @@ def select_nvfp4_moe_backend( ...@@ -175,29 +181,21 @@ def select_nvfp4_moe_backend(
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat, activation_format: mk.FusedMoEActivationFormat,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend) for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format k_cls, config, weight_key, activation_key, activation_format
) )
if supported: if supported:
logger.info_once(_make_log_backend(backend)) logger.info_once(_make_log_backend(backend))
return backend, k_cls return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason)) raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user. # Handle explicit moe_backend from user.
runner_backend = config.moe_backend runner_backend = config.moe_backend
if runner_backend != "auto": if runner_backend != "auto":
requested_backend = map_nvfp4_backend(runner_backend) requested_backend = map_nvfp4_backend(runner_backend)
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise( return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format requested_backend, config, weight_key, activation_key, activation_format
) )
...@@ -210,36 +208,14 @@ def select_nvfp4_moe_backend( ...@@ -210,36 +208,14 @@ def select_nvfp4_moe_backend(
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it. # If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend() backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
return _return_or_raise(
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: backend, config, weight_key, activation_key, activation_format
backend = NvFp4MoeBackend.FLASHINFER_TRTLLM )
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
else:
backend = fi_2_vllm_backend_map[fi_backend]
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else: else:
# If the user is not explicit about the backend, try each. # If the user is not explicit about the backend, try each.
for backend in FLASHINFER_NVFP4_MOE_BACKENDS: for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: for k_cls in backend_to_kernel_cls(backend):
k_cls = None
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, k_cls,
config, config,
...@@ -247,13 +223,13 @@ def select_nvfp4_moe_backend( ...@@ -247,13 +223,13 @@ def select_nvfp4_moe_backend(
activation_key, activation_key,
activation_format, activation_format,
) )
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
return backend, None return backend, k_cls
else: else:
logger.debug_once( logger.debug_once(
_make_log_unsupported(backend, reason), scope="local" _make_log_unsupported(backend, reason), scope="local"
) )
raise NotImplementedError( raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
...@@ -268,16 +244,7 @@ def select_nvfp4_moe_backend( ...@@ -268,16 +244,7 @@ def select_nvfp4_moe_backend(
# Select kernels in order of backend. # Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS: for backend in AVAILABLE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: for k_cls in backend_to_kernel_cls(backend):
k_cls = None # type: ignore[assignment]
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, k_cls,
config, config,
...@@ -286,11 +253,11 @@ def select_nvfp4_moe_backend( ...@@ -286,11 +253,11 @@ def select_nvfp4_moe_backend(
activation_format, activation_format,
) )
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls return backend, k_cls
else: else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local") logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError( raise NotImplementedError(
"No NvFp4 MoE backend supports the deployment configuration." "No NvFp4 MoE backend supports the deployment configuration."
...@@ -398,12 +365,8 @@ def make_nvfp4_moe_quant_config( ...@@ -398,12 +365,8 @@ def make_nvfp4_moe_quant_config(
w2_scale_2: torch.Tensor, w2_scale_2: torch.Tensor,
a13_scale: torch.Tensor, a13_scale: torch.Tensor,
a2_scale: torch.Tensor, a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig:
UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM] if backend == NvFp4MoeBackend.MARLIN:
if backend in UNSUPPORTED:
return None
elif backend == NvFp4MoeBackend.MARLIN:
return nvfp4_w4a16_moe_quant_config( return nvfp4_w4a16_moe_quant_config(
g1_alphas=w13_scale_2, g1_alphas=w13_scale_2,
g2_alphas=w2_scale_2, g2_alphas=w2_scale_2,
...@@ -420,22 +383,27 @@ def make_nvfp4_moe_quant_config( ...@@ -420,22 +383,27 @@ def make_nvfp4_moe_quant_config(
a2_gscale=(1.0 / a2_scale), a2_gscale=(1.0 / a2_scale),
w1_scale=w13_scale, w1_scale=w13_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
# NOTE(rob): this is a hack until the MoE kernels
# create their own quant configs. TRTLLM kernel
# does not accept swizzled input quant scales.
is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
) )
def make_nvfp4_moe_kernel( def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig, moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEExperts],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEKernel:
# Create Prepare/Finalize. # Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize( prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config, moe=moe_config,
quant_config=moe_quant_config, quant_config=moe_quant_config,
routing_tables=routing_tables, routing_tables=routing_tables,
allow_new_interface=True, allow_new_interface=True,
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
) )
assert prepare_finalize is not None assert prepare_finalize is not None
...@@ -460,7 +428,7 @@ def make_nvfp4_moe_kernel( ...@@ -460,7 +428,7 @@ def make_nvfp4_moe_kernel(
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explicit in # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=( shared_experts=(
......
...@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( ...@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_bf16, is_supported_config_trtllm_bf16,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoDPEPModular,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31, swap_w13_to_w31,
...@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel( ...@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend, backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
) -> mk.FusedMoEModularKernel | None: ) -> mk.FusedMoEKernel | None:
if backend in UNSUPPORTED_BACKEND: if backend in UNSUPPORTED_BACKEND:
return None return None
...@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel( ...@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel(
FlashInferExperts, FlashInferExperts,
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
FlashInferExperts( FlashInferExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel( ...@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel(
AiterExperts, AiterExperts,
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
AiterExperts( AiterExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -243,8 +243,8 @@ def make_unquantized_moe_kernel( ...@@ -243,8 +243,8 @@ def make_unquantized_moe_kernel(
elif backend == UnquantizedMoeBackend.TRITON: elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe import TritonExperts
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
TritonExperts( TritonExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -254,8 +254,8 @@ def make_unquantized_moe_kernel( ...@@ -254,8 +254,8 @@ def make_unquantized_moe_kernel(
elif backend == UnquantizedMoeBackend.XPU: elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe import XPUExperts from vllm.model_executor.layers.fused_moe import XPUExperts
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
XPUExperts( XPUExperts(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import (
MoEPrepareAndFinalizeNaiveDPEPModular,
MoEPrepareAndFinalizeNaiveDPEPMonolithic,
make_moe_prepare_and_finalize_naive_dp_ep,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
MoEPrepareAndFinalizeNoDPEPModular,
MoEPrepareAndFinalizeNoDPEPMonolithic,
make_moe_prepare_and_finalize_no_dp_ep,
)
__all__ = [
"MoEPrepareAndFinalizeNaiveDPEPMonolithic",
"MoEPrepareAndFinalizeNaiveDPEPModular",
"make_moe_prepare_and_finalize_naive_dp_ep",
"MoEPrepareAndFinalizeNoDPEPMonolithic",
"MoEPrepareAndFinalizeNoDPEPModular",
"make_moe_prepare_and_finalize_no_dp_ep",
]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
...@@ -14,7 +13,68 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input ...@@ -14,7 +13,68 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave from vllm.utils.flashinfer import nvfp4_block_scale_interleave
class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): def _quantize_and_setup_dispatch(
a1: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
# Defer input quantization to the MoE kernel.
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
input_sf = (
quant_config.a1_gscale
if quant_config.use_nvfp4_w4a4
else quant_config.a1_scale
)
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
a1q, a1q_scale = a1q, a1q_scale = moe_kernel_quantize_input(
a1,
input_sf,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
return a1q, scales
def _unwrap_scale_and_prepare_for_moe(
scales: list[torch.Tensor] | None,
quant_config: FusedMoEQuantConfig,
) -> torch.Tensor:
assert scales is not None and len(scales) == 1
a1q_scale = scales[0]
# Apply swizzling after a2a if the MoE kernel needs it.
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q_scale
class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
"""
Naive Prepare/Finalize for Dp/Ep case for Modular Kernels.
Uses Torch AR/RS or AR for dispatch/combine operations, applied
to the topk weights and ids.
"""
def __init__( def __init__(
self, self,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
...@@ -51,6 +111,8 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): ...@@ -51,6 +111,8 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False, defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
"""Quantize and Dispatch Topk Weights and Topk Ids."""
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
assert topk == 1, ( assert topk == 1, (
...@@ -59,30 +121,7 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): ...@@ -59,30 +121,7 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
# Note: do not use inplace for shared experts overlap # Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quantization to the MoE kernel. a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
use_nvfp4 = quant_config.use_nvfp4_w4a4
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
res = get_ep_group().dispatch( res = get_ep_group().dispatch(
a1q, a1q,
...@@ -91,17 +130,13 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): ...@@ -91,17 +130,13 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales, extra_tensors=scales,
) )
if skip_gather_scales:
if scales is None:
a1q, topk_weights, topk_ids = res a1q, topk_weights, topk_ids = res
a1q_scale = None
else: else:
a1q, topk_weights, topk_ids, scales = res a1q, topk_weights, topk_ids, scales = res
assert scales is not None and len(scales) == 1 a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
a1q_scale = scales[0]
if quant_config.quant_dtype == "nvfp4":
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights
...@@ -130,8 +165,22 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): ...@@ -130,8 +165,22 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
) )
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
"""MoE prepare and finalize without expert parallelism.""" """
Naive Prepare/Finalize for Dp/Ep case for Modular Kernels.
Uses Torch AR/RS or AR for dispatch/combine operations, applied
to the router logits (the MoE kernel runs the router internally).
"""
def __init__(
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__()
self.is_sequence_parallel = is_sequence_parallel
self._num_dispatchers = num_dispatchers
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
...@@ -144,7 +193,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -144,7 +193,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
return None return None
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return 1 return self._num_dispatchers
def output_is_reduced(self) -> bool: def output_is_reduced(self) -> bool:
return False return False
...@@ -152,58 +201,53 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -152,58 +201,53 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
topk_weights: torch.Tensor, router_logits: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False, defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareMonolithicResultType:
if apply_router_weight_on_input: """Quantize and Dispatch Router Logits."""
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quant to moe kernel for backends (e.g. AITER, FI) a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
# which use a single kernel call for quant + experts.
if defer_input_quant:
return a1, None, None, None, None
input_sf = ( res = get_ep_group().dispatch_router_logits(
quant_config.a1_gscale a1q,
if quant_config.use_nvfp4_w4a4 router_logits,
else quant_config.a1_scale is_sequence_parallel=self.is_sequence_parallel,
) extra_tensors=scales,
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
input_sf,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
) )
return a1q, a1q_scale, None, None, None if scales is None:
a1q, router_logits = res
a1q_scale = None
else:
a1q, router_logits, scales = res
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
return a1q, a1q_scale, router_logits
def finalize( def finalize(
self, self,
output: torch.Tensor,
fused_expert_output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, ) -> torch.Tensor:
topk_ids: torch.Tensor, out = get_ep_group().combine(
apply_router_weight_on_input: bool, fused_expert_output, is_sequence_parallel=self.is_sequence_parallel
weight_and_reduce_impl: mk.TopKWeightAndReduce, )
) -> None: return out
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply( def make_moe_prepare_and_finalize_naive_dp_ep(
output=output, use_monolithic: bool,
fused_expert_output=fused_expert_output, is_sequence_parallel: bool = False,
topk_weights=topk_weights, num_dispatchers: int = 1,
topk_ids=topk_ids, ) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic:
apply_router_weight_on_input=apply_router_weight_on_input, return (
MoEPrepareAndFinalizeNaiveDPEPMonolithic(
is_sequence_parallel=is_sequence_parallel,
num_dispatchers=num_dispatchers,
)
if use_monolithic
else MoEPrepareAndFinalizeNaiveDPEPModular(
is_sequence_parallel=is_sequence_parallel,
num_dispatchers=num_dispatchers,
) )
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
def _quantize_input(
a1: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
if defer_input_quant:
return a1, None
input_sf = (
quant_config.a1_gscale if quant_config.use_nvfp4_w4a4 else quant_config.a1_scale
)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
input_sf,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
)
return a1q, a1q_scale
class MoEPrepareAndFinalizeNoDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return 1
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
return a1q, a1q_scale, None, None, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class MoEPrepareAndFinalizeNoDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return 1
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
router_logits: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareMonolithicResultType:
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
return a1q, a1q_scale, router_logits
def finalize(
self,
fused_expert_output: torch.Tensor,
) -> torch.Tensor:
return fused_expert_output
def make_moe_prepare_and_finalize_no_dp_ep(
use_monolithic: bool,
) -> MoEPrepareAndFinalizeNoDPEPModular | MoEPrepareAndFinalizeNoDPEPMonolithic:
return (
MoEPrepareAndFinalizeNoDPEPMonolithic()
if use_monolithic
else MoEPrepareAndFinalizeNoDPEPModular()
)
...@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts( ...@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts(
) )
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): class AiterExperts(mk.FusedMoEExpertsModular):
@property @property
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(self) -> bool:
return True return True
......
...@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike(): ...@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike():
# TODO(bowen): When using `FusedMoEModularKernel`, this # TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since # can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert # `FusedMoEPrepareAndFinalizeModular` will return the expert
# token count, in some cases directly from the kernel. # token count, in some cases directly from the kernel.
# However, now there are many code paths not using # However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`, # the modular kernel, e.g. calling `fused_experts`,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment