Unverified Commit 97995f63 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Signed-off-by: default avatarRobert Shaw <robertgshaw2@gmail.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
...@@ -320,8 +320,8 @@ class DefaultMoERunner(MoERunner): ...@@ -320,8 +320,8 @@ class DefaultMoERunner(MoERunner):
""" """
assert self.quant_method is not None assert self.quant_method is not None
return ( return (
self.quant_method.moe_mk is not None self.quant_method.moe_kernel is not None
and self.quant_method.moe_mk.output_is_reduced() and self.quant_method.moe_kernel.output_is_reduced()
) )
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
...@@ -640,45 +640,6 @@ class DefaultMoERunner(MoERunner): ...@@ -640,45 +640,6 @@ class DefaultMoERunner(MoERunner):
) )
with sp_ctx: with sp_ctx:
extra_tensors = None
if do_naive_dispatch_combine:
post_quant_allgather = (
self.quant_method is not None
and self.moe_config.dp_size > 1
and self.moe_config.use_ep
and getattr(self.quant_method, "do_post_quant_allgather", False)
)
if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = (
self.quant_method.prepare_dp_allgather_tensor(
layer, hidden_states, router_logits
)
)
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch_router_logits(
hidden_states_to_dispatch,
router_logits,
self.moe_config.is_sequence_parallel,
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
(
orig_hidden_states,
router_logits,
extra_tensors_combined,
) = dispatch_res
hidden_states_combined = (
orig_hidden_states,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res
orig_hidden_states = hidden_states_combined
else:
orig_hidden_states = hidden_states
# Run shared experts before matrix multiply. # Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states. # because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream: if has_separate_shared_experts and not use_shared_experts_stream:
...@@ -688,6 +649,17 @@ class DefaultMoERunner(MoERunner): ...@@ -688,6 +649,17 @@ class DefaultMoERunner(MoERunner):
) )
shared_output = self.shared_experts(shared_input) shared_output = self.shared_experts(shared_input)
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
hidden_states,
router_logits,
self.moe_config.is_sequence_parallel,
)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For # NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe # simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstract to better support PCP. # we should modify All2AllManager abstract to better support PCP.
...@@ -701,31 +673,22 @@ class DefaultMoERunner(MoERunner): ...@@ -701,31 +673,22 @@ class DefaultMoERunner(MoERunner):
dim=0, dim=0,
) )
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
# Figure out nicer way to do this.
if do_naive_dispatch_combine:
x = hidden_states_combined
x_orig = orig_hidden_states
else:
x = hidden_states
x_orig = hidden_states
# Matrix multiply. # Matrix multiply.
if self.quant_method.is_monolithic: if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic( final_hidden_states = self.quant_method.apply_monolithic(
layer=layer, layer=layer,
x=x, x=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
else: else:
topk_weights, topk_ids = self.router.select_experts( topk_weights, topk_ids = self.router.select_experts(
hidden_states=x_orig, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=layer, layer=layer,
x=x, # The type signture of this is wrong due to the hack. x=hidden_states,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
shared_experts_input=shared_input, shared_experts_input=shared_input,
......
...@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
""" """
Useful in the case when some FusedMoEPermuteExpertsUnpermute Useful in the case when some FusedMoEExpertsModular
implementation does not perform weight application and reduction implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize but cannot address the needs of all the compatible PrepareAndFinalize
implementations. implementations.
...@@ -62,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): ...@@ -62,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
if output is None: if output is None:
return fused_expert_output return fused_expert_output
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output` # MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
# tensor. # tensor.
assert output.size() == fused_expert_output.size(), ( assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. " "output shape is expected to match the fused_expert_output shape. "
......
...@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts): ...@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
@staticmethod @staticmethod
def get_clses() -> tuple[ def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
]: ]:
return (CutlassExpertsFp8, TritonExperts) return (CutlassExpertsFp8, TritonExperts)
...@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts): ...@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
# Small batch fallback for sm100. # Small batch fallback for sm100.
if self.is_sm100 and hidden_states.shape[0] <= 8: if self.is_sm100 and hidden_states.shape[0] <= 8:
return self.fallback_experts return self.fallback_experts
......
...@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts): ...@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
@staticmethod @staticmethod
def get_clses() -> tuple[ def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
]: ]:
return (DeepGemmExperts, TritonExperts) return (DeepGemmExperts, TritonExperts)
...@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts): ...@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2): if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
return self.experts return self.experts
else: else:
......
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"""TensorRT-LLM-based fused MoE expert implementation.""" """TensorRT-LLM-based fused MoE expert implementation."""
def __init__( def __init__(
......
...@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( ...@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute, FusedMoEExpertsModular,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalizeModular,
) )
from vllm.model_executor.layers.fused_moe.oracle.unquantized import ( from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
UnquantizedMoeBackend, UnquantizedMoeBackend,
...@@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.rocm_aiter_moe_enabled = ( self.rocm_aiter_moe_enabled = (
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
) )
self.kernel: mk.FusedMoEModularKernel | None = None self.kernel: mk.FusedMoEKernel | None = None
self._is_monolithic = ( self._is_monolithic = (
current_platform.is_cpu() current_platform.is_cpu()
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
...@@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None: ) -> FusedMoEPrepareAndFinalizeModular | None:
if self.unquantized_backend == UnquantizedMoeBackend.AITER: if self.unquantized_backend == UnquantizedMoeBackend.AITER:
return None return None
else: else:
...@@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEExpertsModular:
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
if ( if (
prepare_finalize.activation_format prepare_finalize.activation_format
...@@ -325,7 +325,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -325,7 +325,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.kernel is not None
return self.kernel( return self.kernel.apply(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
......
...@@ -23,7 +23,7 @@ if current_platform.is_xpu(): ...@@ -23,7 +23,7 @@ if current_platform.is_xpu():
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute): class XPUExperts(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
......
...@@ -19,8 +19,8 @@ from vllm.logger import init_logger ...@@ -19,8 +19,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEActivationFormat, FusedMoEActivationFormat,
FusedMoEExpertsModular,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod, UnquantizedFusedMoEMethod,
) )
...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( ...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, fused_marlin_moe,
) )
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
...@@ -59,18 +58,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress ...@@ -59,18 +58,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP, WNA16_SUPPORTED_TYPES_MAP,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
flashinfer_trtllm_mxint4_moe, flashinfer_trtllm_mxint4_moe,
is_flashinfer_mxint4_moe_available, is_flashinfer_mxint4_moe_available,
prepare_static_weights_for_trtllm_mxint4_moe, prepare_static_weights_for_trtllm_mxint4_moe,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe,
...@@ -336,7 +328,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -336,7 +328,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None: if self.moe_quant_config is not None:
self.moe_mk = make_nvfp4_moe_kernel( self.moe_kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
...@@ -352,8 +344,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -352,8 +344,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None assert self.moe_kernel is not None
return self.moe_mk( return self.moe_kernel.apply(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -562,43 +554,27 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -562,43 +554,27 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = a13_scale layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale layer.w2_input_scale = a2_scale
# Setup modular kernel for TP case and naive DP/EP case. # Setup modular kernel.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: assert self.experts_cls is not None
assert self.experts_cls is not None self.moe_kernel = make_nvfp4_moe_kernel(
self.moe_mk = make_nvfp4_moe_kernel( moe_quant_config=self.moe_quant_config,
moe_quant_config=self.moe_quant_config, moe_config=self.moe,
moe_config=self.moe, experts_cls=self.experts_cls,
experts_cls=self.experts_cls, shared_experts=layer.shared_experts,
shared_experts=layer.shared_experts, routing_tables=layer._maybe_init_expert_routing_tables(),
routing_tables=layer._maybe_init_expert_routing_tables(), )
)
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization " f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called." "logic. This function should not be called."
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return make_nvfp4_moe_quant_config( return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend, backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale, w13_scale=layer.w13_weight_scale,
...@@ -609,13 +585,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -609,13 +585,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
) )
@property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic( def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -623,24 +592,20 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -623,24 +592,20 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert layer.activation == MoEActivation.SILU, ( assert self.moe_kernel is not None
f"Only SiLU activation is supported, not {layer.activation}." return self.moe_kernel.apply_monolithic(
) x,
assert ( layer.w13_weight,
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM layer.w2_weight,
and not layer.enable_eplb router_logits,
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group, topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
) )
def apply( def apply(
...@@ -651,34 +616,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -651,34 +616,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert self.moe_kernel is not None
return self.moe_kernel.apply(
# EPLB path x,
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: layer.w13_weight,
assert layer.enable_eplb layer.w2_weight,
return flashinfer_trtllm_fp4_routed_moe( topk_weights,
layer=layer, topk_ids,
x=x, activation=layer.activation,
topk_ids=topk_ids, global_num_experts=layer.global_num_experts,
topk_weights=topk_weights, expert_map=layer.expert_map,
top_k=layer.top_k, apply_router_weight_on_input=layer.apply_router_weight_on_input,
activation=layer.activation, shared_experts_input=shared_experts_input,
global_num_experts=layer.global_num_experts, )
)
else:
assert self.moe_mk is not None
return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...@@ -966,7 +916,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -966,7 +916,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel( self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
...@@ -978,94 +928,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -978,94 +928,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization " f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called." "logic. This function should not be called."
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
self, layer: torch.nn.Module is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
) -> FusedMoEQuantConfig | None:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
return make_fp8_moe_quant_config( return make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
w1_scale=w1_scale, w1_scale=layer.w13_weight_scale,
w2_scale=w2_scale, w2_scale=layer.w2_weight_scale,
a1_scale=a1_scale, a1_scale=layer.w13_input_scale,
a2_scale=a2_scale, a2_scale=layer.w2_input_scale,
per_act_token_quant=( per_act_token_quant=is_per_token,
self.input_quant.strategy == QuantizationStrategy.TOKEN per_out_ch_quant=is_per_token,
),
per_out_ch_quant=(self.input_quant.strategy == QuantizationStrategy.TOKEN),
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
) )
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic( def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.moe_kernel is not None
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM return self.moe_kernel.apply_monolithic(
assert layer.activation == MoEActivation.SILU, ( x,
f"Only SiLU activation is supported, not {layer.activation}." layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
) )
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -1075,8 +978,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1075,8 +978,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_mk is not None assert self.moe_kernel is not None
return self.moe_mk( return self.moe_kernel.apply(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1652,9 +1555,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1652,9 +1555,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
assert self.num_bits == 4, "only supporting w4" assert self.num_bits == 4, "only supporting w4"
layer.w13_weight = layer.w13_weight_packed layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed layer.w2_weight = layer.w2_weight_packed
...@@ -1943,9 +1846,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1943,9 +1846,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
if self.moe.is_lora_enabled: if self.moe.is_lora_enabled:
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
...@@ -2527,7 +2430,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2527,7 +2430,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
return super().maybe_make_prepare_finalize(routing_tables) return super().maybe_make_prepare_finalize(routing_tables)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -2548,9 +2451,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2548,9 +2451,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
assert ( assert (
prepare_finalize.activation_format == FusedMoEActivationFormat.Standard prepare_finalize.activation_format == FusedMoEActivationFormat.Standard
...@@ -2558,7 +2461,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2558,7 +2461,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8 from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8
experts: FusedMoEPermuteExpertsUnpermute experts: FusedMoEExpertsModular
logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__) logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__)
experts = CutlassExpertsW4A8Fp8( experts = CutlassExpertsW4A8Fp8(
......
...@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
MoEActivation,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
...@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
create_fp8_input_scale, create_fp8_input_scale,
...@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel( self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
...@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
raise ValueError( raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization " f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called." "logic. This function should not be called."
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# TRTLLM does not use Modular Kernel.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale a1_scale = layer.w13_input_scale
...@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return True return True
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic( def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
# TODO(rob): convert this to MK. x,
if layer.enable_eplb: layer.w13_weight,
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") layer.w2_weight,
assert layer.activation == MoEActivation.SILU, ( router_logits,
f"Expected 'silu' activation but got {layer.activation}" activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
) )
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert not self.is_monolithic assert not self.is_monolithic
return self.moe_mk( assert self.moe_kernel is not None
return self.moe_kernel.apply(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
......
...@@ -13,7 +13,6 @@ from vllm.model_executor.kernels.linear import ( ...@@ -13,7 +13,6 @@ from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
) )
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
...@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
...@@ -746,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -746,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError( raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization " f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called." "logic. This function should not be called."
...@@ -754,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -754,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
raise ValueError( raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization " f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called." "logic. This function should not be called."
...@@ -871,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -871,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# Setup modular kernel. # Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: assert self.experts_cls is not None
assert self.experts_cls is not None self.moe_kernel = make_fp8_moe_kernel(
self.moe_mk = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config,
moe_quant_config=self.moe_quant_config, moe_config=self.moe,
moe_config=self.moe, fp8_backend=self.fp8_backend,
fp8_backend=self.fp8_backend, experts_cls=self.experts_cls,
experts_cls=self.experts_cls, routing_tables=layer._maybe_init_expert_routing_tables(),
routing_tables=layer._maybe_init_expert_routing_tables(), shared_experts=layer.shared_experts,
shared_experts=layer.shared_experts, )
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13 = layer.w13_weight w13 = layer.w13_weight
...@@ -913,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -913,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
w1_scale = layer.w13_weight_scale w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale w2_scale = layer.w2_weight_scale
a1_scale = layer.w13_input_scale a1_scale = layer.w13_input_scale
...@@ -929,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -929,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
a2_scale=a2_scale, a2_scale=a2_scale,
) )
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic( def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -940,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -940,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM assert self.moe_kernel is not None
if layer.enable_eplb: return self.moe_kernel.apply_monolithic(
raise NotImplementedError( x,
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend." layer.w13_weight,
) layer.w2_weight,
# TODO(rob): this validation should happen at kernel selection router_logits,
# time in the oracle rather than here. activation=layer.activation,
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert layer.activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {layer.activation} found instead."
)
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
top_k=layer.top_k, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group, topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input, e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
) )
def apply( def apply(
...@@ -973,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -973,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_kernel is not None
# TODO(rob): this validation should happen at kernel selection return self.moe_kernel.apply(
# time in the oracle rather than here. x,
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: layer.w13_weight,
assert layer.activation in ( layer.w2_weight,
MoEActivation.SILU, topk_weights,
MoEActivation.RELU2_NO_MUL, topk_ids,
), (
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {layer.activation}"
)
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
...@@ -1235,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1235,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization " f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called." "logic. This function should not be called."
...@@ -1420,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1420,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale) replace_parameter(layer, "w2_input_scale", a2_scale)
# Setup modular kernel for TP case and naive DP/EP case. # Setup modular kernel.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: assert self.experts_cls is not None
assert self.experts_cls is not None self.moe_kernel = make_nvfp4_moe_kernel(
self.moe_mk = make_nvfp4_moe_kernel( moe_quant_config=self.moe_quant_config,
moe_quant_config=self.moe_quant_config, moe_config=self.moe,
moe_config=self.moe, experts_cls=self.experts_cls,
experts_cls=self.experts_cls, shared_experts=layer.shared_experts,
shared_experts=layer.shared_experts, routing_tables=layer._maybe_init_expert_routing_tables(),
routing_tables=layer._maybe_init_expert_routing_tables(),
)
@property
def do_post_quant_allgather(self):
return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
def prepare_dp_allgather_tensor(
self,
layer: FusedMoE,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise RuntimeError(
"prepare_dp_allgather_tensor is only supported for "
"FlashInfer TRTLLM NVFP4 MoE backend."
)
import flashinfer
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
layer.a1_gscale,
is_sf_swizzled_layout=False,
) )
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
return hidden_states_fp4, extra_tensors
def get_fused_moe_quant_config( def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return make_nvfp4_moe_quant_config( return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend, backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale, w13_scale=layer.w13_weight_scale,
...@@ -1479,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1479,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return True return True
@property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic( def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -1493,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1493,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert ( assert self.moe_kernel is not None
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM return self.moe_kernel.apply_monolithic(
and not layer.enable_eplb x,
) layer.w13_weight,
layer.w2_weight,
return flashinfer_trtllm_fp4_moe( router_logits,
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group, topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
) )
def apply( def apply(
...@@ -1520,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1520,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_kernel is not None
# EPLB path return self.moe_kernel.apply(
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: x,
assert layer.enable_eplb layer.w13_weight,
return flashinfer_trtllm_fp4_routed_moe( layer.w2_weight,
layer=layer, topk_weights,
x=x, topk_ids,
topk_ids=topk_ids, activation=layer.activation,
topk_weights=topk_weights, global_num_experts=layer.global_num_experts,
top_k=layer.top_k, expert_map=layer.expert_map,
activation=layer.activation, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, shared_experts_input=shared_experts_input,
) )
else:
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
......
...@@ -266,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -266,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
) )
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends # Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_mk: mk.FusedMoEModularKernel | None = None self.moe_kernel: mk.FusedMoEKernel | None = None
def create_weights( def create_weights(
self, self,
...@@ -440,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -440,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
) )
assert prepare_finalize is not None assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel( self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize, prepare_finalize,
MarlinExperts( MarlinExperts(
self.moe, self.moe,
...@@ -789,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -789,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
) )
assert prepare_finalize is not None assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel( self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize, prepare_finalize,
FlashInferExperts( FlashInferExperts(
moe_config=self.moe, moe_config=self.moe,
...@@ -954,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -954,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
if ( if (
prepare_finalize.activation_format prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts == mk.FusedMoEActivationFormat.BatchedExperts
...@@ -1043,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1043,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.MARLIN or self.mxfp4_backend == Mxfp4Backend.MARLIN
) )
assert self.moe_mk is not None assert self.moe_kernel is not None
return self.moe_mk( return self.moe_kernel.apply(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
......
...@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING ...@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
align_fp4_moe_weights_for_fi, align_fp4_moe_weights_for_fi,
) )
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutlass_fused_moe,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
...@@ -42,92 +32,15 @@ __all__ = [ ...@@ -42,92 +32,15 @@ __all__ = [
"reorder_w1w3_to_w3w1", "reorder_w1w3_to_w3w1",
] ]
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""Supports non-gated MoE."""
return True
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Nvfp4 quantization."""
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
def _supports_routing_method(
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
]
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""
TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
the naive dispatch/combine path. DeepEP HT only implements dispatch() for
the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
"""
return not moe_parallel_config.use_deepep_ht_kernels
def is_supported_config_trtllm(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason(f"current device {current_platform.device_name}")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
elif not _supports_routing_method(moe_config.routing_method):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason(f"activation format {activation_format}")
elif moe_config.hidden_dim % 512 != 0:
return False, _make_reason(
f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
)
return True, None def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.has_device_capability(100)
)
def reorder_w1w3_to_w3w1( def reorder_w1w3_to_w3w1(
...@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe( ...@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
) )
def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
activation: MoEActivation,
global_num_experts: int,
num_expert_group: int | None,
topk_group: int | None,
custom_routing_function: object | None,
e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
router_logits: Router logits for expert selection
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group
custom_routing_function: Custom routing function (e.g., Llama4)
e_score_correction_bias: Optional routing bias correction
Returns:
Output tensor from the MoE layer
"""
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {activation} found instead."
)
# Quantize input to FP4
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is the already quantized
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
routing_method_type = layer.routing_method_type
if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4
# Cast to Fp32 (required by kernel).
router_logits = (
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
# Determine activation type
activation_type = activation_to_flashinfer_int(layer.activation)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).reshape(*hidden_states_fp4.shape[:-1], -1),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group if num_expert_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=routing_method_type,
do_finalize=True,
activation_type=activation_type,
)[0]
return out
def flashinfer_trtllm_fp4_routed_moe(
layer: torch.nn.Module,
x: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
activation: MoEActivation,
global_num_experts: int,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
input top k expert indices and scores rather than computing
top k expert indices from scores.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
topk_ids: Ids of selected experts
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
Returns:
Output tensor from the MoE layer
"""
import flashinfer
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
f"{activation} found instead."
)
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
if isinstance(x, tuple):
# Hidden_states is the already quantized
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
topk_ids=packed_tensor,
routing_bias=None,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).reshape(*hidden_states_fp4.shape[:-1], -1),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=0,
topk_group=0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
do_finalize=True,
)[0]
return out
def prepare_nvfp4_moe_layer_for_fi_or_cutlass( def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend: "NvFp4MoeBackend", backend: "NvFp4MoeBackend",
layer: "FusedMoE", layer: "FusedMoE",
...@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ...@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
) )
) )
layer.intermediate_size_per_partition = padded_intermediate layer.intermediate_size_per_partition = padded_intermediate
layer.moe_config.intermediate_size_per_partition = padded_intermediate
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe( w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
w13, w13,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING
import torch import torch
...@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation ...@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from flashinfer.fused_moe.core import ActivationType
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum): ...@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
def activation_to_flashinfer_int(activation: MoEActivation) -> int: def activation_to_flashinfer_int(activation: MoEActivation) -> int:
return activation_to_flashinfer_type(activation).value
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
from flashinfer.fused_moe.core import ActivationType from flashinfer.fused_moe.core import ActivationType
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively # silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
...@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int: ...@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
MoEActivation.GELU: ActivationType.Geglu, MoEActivation.GELU: ActivationType.Geglu,
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2, MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
} }
return ACTIVATION_TO_FI_ACTIVATION[activation].value return ACTIVATION_TO_FI_ACTIVATION[activation]
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
...@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( ...@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
) )
def register_scales_for_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> None:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
layer.w2_input_scale_inv = 1.0 / w2_input_scale
layer.output1_scales_gate_scalar = g1_alphas
if layer.activation.is_gated:
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
else:
layer.output1_scales_scalar = (
torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
)
layer.output2_scales_scalar = g2_alphas
def apply_fi_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
global_num_experts: int,
apply_router_weight_on_input: bool,
) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from vllm.model_executor.models.llama4 import Llama4MoE
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
)
if layer.routing_method_type == RoutingMethodType.Llama4:
assert (
not layer.renormalize
and layer.custom_routing_function == Llama4MoE.custom_routing_function
), (
"FusedMoE flashinfer kernels with Llama4 routing method are only "
"supported for Llama4"
)
else:
assert layer.custom_routing_function is None, (
"Custom routing function is only supported for Llama4"
)
activation_type = activation_to_flashinfer_int(layer.activation)
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
input_scale=layer.w13_input_scale,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
output1_scales_scalar=layer.output1_scales_scalar,
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
output2_scales_scalar=layer.output2_scales_scalar,
num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=layer.routing_method_type,
activation_type=activation_type,
)
def make_fp8_moe_alpha_scales_for_fi(
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
g1_alphas = (w13_scale * w13_input_scale).squeeze()
g2_alphas = (w2_scale * w2_input_scale).squeeze()
return g1_alphas, g2_alphas
def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = { backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS, "throughput": FlashinferMoeBackend.CUTLASS,
...@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi( ...@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
min_alignment, min_alignment,
) )
layer.intermediate_size_per_partition = new_intermediate layer.intermediate_size_per_partition = new_intermediate
layer.moe_config.intermediate_size_per_partition = new_intermediate
# FI kernels require W31 layout rather than W13. # FI kernels require W31 layout rather than W13.
if layer.moe_config.is_act_and_mul: if layer.moe_config.is_act_and_mul:
...@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi( ...@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
w13_scale = swap_w13_to_w31(w13_scale) w13_scale = swap_w13_to_w31(w13_scale)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales. Note that we do not register # and registration of alpha scales.
# as nn.Parameters since they are not needed for weight-reloading.
if is_trtllm and not block_quant: if is_trtllm and not block_quant:
assert w13_input_scale is not None assert w13_input_scale is not None
assert w2_input_scale is not None assert w2_input_scale is not None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated) rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel. # Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused # Some FP8 models have near-zero block scales (~1e-23) for dead/unused
......
...@@ -172,7 +172,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: ...@@ -172,7 +172,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
# Further check if the ModularKernel implementation uses the DeepGemmExperts # Further check if the ModularKernel implementation uses the DeepGemmExperts
return isinstance( return isinstance(
module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts) module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts)
) )
......
...@@ -88,9 +88,14 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None: ...@@ -88,9 +88,14 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
Without autotuning, FlashInfer will rely on heuristics, which may Without autotuning, FlashInfer will rely on heuristics, which may
be significantly slower. be significantly slower.
""" """
from vllm.utils.flashinfer import autotune import vllm.utils.flashinfer as fi_utils
with torch.inference_mode(), fi_utils.autotune():
# Certain FlashInfer kernels (e.g. nvfp4 routed moe) are
# incompatible with autotuning. This state is used to skip
# those kernels during the autotuning process.
fi_utils._is_fi_autotuning = True
with torch.inference_mode(), autotune():
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
# When autotuning with number of tokens m, flashinfer will autotune # When autotuning with number of tokens m, flashinfer will autotune
# operations for all number of tokens up to m. # operations for all number of tokens up to m.
...@@ -100,3 +105,5 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None: ...@@ -100,3 +105,5 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
skip_eplb=True, skip_eplb=True,
is_profile=True, is_profile=True,
) )
fi_utils._is_fi_autotuning = False
...@@ -140,6 +140,7 @@ autotune = _lazy_import_wrapper( ...@@ -140,6 +140,7 @@ autotune = _lazy_import_wrapper(
"autotune", "autotune",
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(), fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
) )
_is_fi_autotuning: bool = False
@functools.cache @functools.cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment