"vscode:/vscode.git/clone" did not exist on "f6818a92cbad18b7bac0b5cb424549ba941e11b9"
Unverified Commit 97995f63 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Signed-off-by: default avatarRobert Shaw <robertgshaw2@gmail.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
......@@ -320,8 +320,8 @@ class DefaultMoERunner(MoERunner):
"""
assert self.quant_method is not None
return (
self.quant_method.moe_mk is not None
and self.quant_method.moe_mk.output_is_reduced()
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
......@@ -640,45 +640,6 @@ class DefaultMoERunner(MoERunner):
)
with sp_ctx:
extra_tensors = None
if do_naive_dispatch_combine:
post_quant_allgather = (
self.quant_method is not None
and self.moe_config.dp_size > 1
and self.moe_config.use_ep
and getattr(self.quant_method, "do_post_quant_allgather", False)
)
if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = (
self.quant_method.prepare_dp_allgather_tensor(
layer, hidden_states, router_logits
)
)
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch_router_logits(
hidden_states_to_dispatch,
router_logits,
self.moe_config.is_sequence_parallel,
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
(
orig_hidden_states,
router_logits,
extra_tensors_combined,
) = dispatch_res
hidden_states_combined = (
orig_hidden_states,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res
orig_hidden_states = hidden_states_combined
else:
orig_hidden_states = hidden_states
# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
......@@ -688,6 +649,17 @@ class DefaultMoERunner(MoERunner):
)
shared_output = self.shared_experts(shared_input)
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
hidden_states,
router_logits,
self.moe_config.is_sequence_parallel,
)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstract to better support PCP.
......@@ -701,31 +673,22 @@ class DefaultMoERunner(MoERunner):
dim=0,
)
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
# Figure out nicer way to do this.
if do_naive_dispatch_combine:
x = hidden_states_combined
x_orig = orig_hidden_states
else:
x = hidden_states
x_orig = hidden_states
# Matrix multiply.
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=layer,
x=x,
x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=x_orig,
hidden_states=hidden_states,
router_logits=router_logits,
)
final_hidden_states = self.quant_method.apply(
layer=layer,
x=x, # The type signture of this is wrong due to the hack.
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_input,
......
......@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
"""
Useful in the case when some FusedMoEPermuteExpertsUnpermute
Useful in the case when some FusedMoEExpertsModular
implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize
implementations.
......@@ -62,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
if output is None:
return fused_expert_output
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
# MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
# tensor.
assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. "
......
......@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEExpertsModular],
]:
return (CutlassExpertsFp8, TritonExperts)
......@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
# Small batch fallback for sm100.
if self.is_sm100 and hidden_states.shape[0] <= 8:
return self.fallback_experts
......
......@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEExpertsModular],
]:
return (DeepGemmExperts, TritonExperts)
......@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
return self.experts
else:
......
......@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"""TensorRT-LLM-based fused MoE expert implementation."""
def __init__(
......
......@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoEExpertsModular,
FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
UnquantizedMoeBackend,
......@@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.rocm_aiter_moe_enabled = (
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
)
self.kernel: mk.FusedMoEModularKernel | None = None
self.kernel: mk.FusedMoEKernel | None = None
self._is_monolithic = (
current_platform.is_cpu()
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
......@@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
) -> FusedMoEPrepareAndFinalizeModular | None:
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
return None
else:
......@@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
) -> FusedMoEExpertsModular:
assert self.moe_quant_config is not None
if (
prepare_finalize.activation_format
......@@ -325,7 +325,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
return self.kernel(
return self.kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......
......@@ -23,7 +23,7 @@ if current_platform.is_xpu():
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
class XPUExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,
......
......@@ -19,8 +19,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEActivationFormat,
FusedMoEExpertsModular,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
......@@ -40,7 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
......@@ -59,18 +58,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
flashinfer_trtllm_mxint4_moe,
is_flashinfer_mxint4_moe_available,
prepare_static_weights_for_trtllm_mxint4_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
......@@ -336,7 +328,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
self.moe_mk = make_nvfp4_moe_kernel(
self.moe_kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
......@@ -352,8 +344,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -562,43 +554,27 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
assert self.experts_cls is not None
self.moe_kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
......@@ -609,13 +585,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale,
)
@property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic(
self,
layer: FusedMoE,
......@@ -623,24 +592,20 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
......@@ -651,34 +616,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
)
else:
assert self.moe_mk is not None
return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
......@@ -966,7 +916,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel(
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
......@@ -978,94 +928,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
return make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_act_token_quant=(
self.input_quant.strategy == QuantizationStrategy.TOKEN
),
per_out_ch_quant=(self.input_quant.strategy == QuantizationStrategy.TOKEN),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=is_per_token,
per_out_ch_quant=is_per_token,
block_shape=self.weight_block_size,
)
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply(
self,
layer: FusedMoE,
......@@ -1075,8 +978,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.moe_mk is not None
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1652,9 +1555,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
assert self.num_bits == 4, "only supporting w4"
layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed
......@@ -1943,9 +1846,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if self.moe.is_lora_enabled:
assert self.moe_quant_config is not None
from vllm.triton_utils import HAS_TRITON
......@@ -2527,7 +2430,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
return super().maybe_make_prepare_finalize(routing_tables)
def get_fused_moe_quant_config(
......@@ -2548,9 +2451,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
assert self.moe_quant_config is not None
assert (
prepare_finalize.activation_format == FusedMoEActivationFormat.Standard
......@@ -2558,7 +2461,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8
experts: FusedMoEPermuteExpertsUnpermute
experts: FusedMoEExpertsModular
logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__)
experts = CutlassExpertsW4A8Fp8(
......
......@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported,
MoEActivation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
......@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
......@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel(
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
......@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# TRTLLM does not use Modular Kernel.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
......@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
......@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == MoEActivation.SILU, (
f"Expected 'silu' activation but got {layer.activation}"
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply(
self,
layer: FusedMoE,
......@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert not self.is_monolithic
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
......
......@@ -13,7 +13,6 @@ from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
......@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
process_fp8_input_tensor_strategy_moe,
......@@ -746,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
......@@ -754,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
......@@ -871,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13 = layer.w13_weight
......@@ -913,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
a1_scale = layer.w13_input_scale
......@@ -929,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
a2_scale=a2_scale,
)
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
......@@ -940,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
)
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert layer.activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {layer.activation} found instead."
)
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
......@@ -973,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
assert layer.activation in (
MoEActivation.SILU,
MoEActivation.RELU2_NO_MUL,
), (
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {layer.activation}"
)
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
......@@ -1235,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
......@@ -1420,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
@property
def do_post_quant_allgather(self):
return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
def prepare_dp_allgather_tensor(
self,
layer: FusedMoE,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise RuntimeError(
"prepare_dp_allgather_tensor is only supported for "
"FlashInfer TRTLLM NVFP4 MoE backend."
)
import flashinfer
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
layer.a1_gscale,
is_sf_swizzled_layout=False,
assert self.experts_cls is not None
self.moe_kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
return hidden_states_fp4, extra_tensors
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
......@@ -1479,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic(
self,
layer: FusedMoE,
......@@ -1493,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
......@@ -1520,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
)
else:
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
......
......@@ -266,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_mk: mk.FusedMoEModularKernel | None = None
self.moe_kernel: mk.FusedMoEKernel | None = None
def create_weights(
self,
......@@ -440,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
MarlinExperts(
self.moe,
......@@ -789,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
FlashInferExperts(
moe_config=self.moe,
......@@ -954,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
......@@ -1043,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.MARLIN
)
assert self.moe_mk is not None
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......
......@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
align_fp4_moe_weights_for_fi,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutlass_fused_moe,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
......@@ -42,92 +32,15 @@ __all__ = [
"reorder_w1w3_to_w3w1",
]
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""Supports non-gated MoE."""
return True
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Nvfp4 quantization."""
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
def _supports_routing_method(
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
]
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""
TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
the naive dispatch/combine path. DeepEP HT only implements dispatch() for
the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
"""
return not moe_parallel_config.use_deepep_ht_kernels
def is_supported_config_trtllm(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason(f"current device {current_platform.device_name}")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
elif not _supports_routing_method(moe_config.routing_method):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason(f"activation format {activation_format}")
elif moe_config.hidden_dim % 512 != 0:
return False, _make_reason(
f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
)
return True, None
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.has_device_capability(100)
)
def reorder_w1w3_to_w3w1(
......@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
)
def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
activation: MoEActivation,
global_num_experts: int,
num_expert_group: int | None,
topk_group: int | None,
custom_routing_function: object | None,
e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
router_logits: Router logits for expert selection
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group
custom_routing_function: Custom routing function (e.g., Llama4)
e_score_correction_bias: Optional routing bias correction
Returns:
Output tensor from the MoE layer
"""
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {activation} found instead."
)
# Quantize input to FP4
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is the already quantized
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
routing_method_type = layer.routing_method_type
if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4
# Cast to Fp32 (required by kernel).
router_logits = (
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
# Determine activation type
activation_type = activation_to_flashinfer_int(layer.activation)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).reshape(*hidden_states_fp4.shape[:-1], -1),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group if num_expert_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=routing_method_type,
do_finalize=True,
activation_type=activation_type,
)[0]
return out
def flashinfer_trtllm_fp4_routed_moe(
layer: torch.nn.Module,
x: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
activation: MoEActivation,
global_num_experts: int,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
input top k expert indices and scores rather than computing
top k expert indices from scores.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
topk_ids: Ids of selected experts
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
Returns:
Output tensor from the MoE layer
"""
import flashinfer
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
f"{activation} found instead."
)
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
if isinstance(x, tuple):
# Hidden_states is the already quantized
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
topk_ids=packed_tensor,
routing_bias=None,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).reshape(*hidden_states_fp4.shape[:-1], -1),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=0,
topk_group=0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
do_finalize=True,
)[0]
return out
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend: "NvFp4MoeBackend",
layer: "FusedMoE",
......@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
)
)
layer.intermediate_size_per_partition = padded_intermediate
layer.moe_config.intermediate_size_per_partition = padded_intermediate
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
w13,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import TYPE_CHECKING
import torch
......@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from flashinfer.fused_moe.core import ActivationType
logger = init_logger(__name__)
......@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
return activation_to_flashinfer_type(activation).value
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
from flashinfer.fused_moe.core import ActivationType
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
......@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
MoEActivation.GELU: ActivationType.Geglu,
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
}
return ACTIVATION_TO_FI_ACTIVATION[activation].value
return ACTIVATION_TO_FI_ACTIVATION[activation]
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
......@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
)
def register_scales_for_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> None:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
layer.w2_input_scale_inv = 1.0 / w2_input_scale
layer.output1_scales_gate_scalar = g1_alphas
if layer.activation.is_gated:
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
else:
layer.output1_scales_scalar = (
torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
)
layer.output2_scales_scalar = g2_alphas
def apply_fi_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
global_num_experts: int,
apply_router_weight_on_input: bool,
) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from vllm.model_executor.models.llama4 import Llama4MoE
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
)
if layer.routing_method_type == RoutingMethodType.Llama4:
assert (
not layer.renormalize
and layer.custom_routing_function == Llama4MoE.custom_routing_function
), (
"FusedMoE flashinfer kernels with Llama4 routing method are only "
"supported for Llama4"
)
else:
assert layer.custom_routing_function is None, (
"Custom routing function is only supported for Llama4"
)
activation_type = activation_to_flashinfer_int(layer.activation)
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
input_scale=layer.w13_input_scale,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
output1_scales_scalar=layer.output1_scales_scalar,
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
output2_scales_scalar=layer.output2_scales_scalar,
num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=layer.routing_method_type,
activation_type=activation_type,
)
def make_fp8_moe_alpha_scales_for_fi(
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
g1_alphas = (w13_scale * w13_input_scale).squeeze()
g2_alphas = (w2_scale * w2_input_scale).squeeze()
return g1_alphas, g2_alphas
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
......@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
min_alignment,
)
layer.intermediate_size_per_partition = new_intermediate
layer.moe_config.intermediate_size_per_partition = new_intermediate
# FI kernels require W31 layout rather than W13.
if layer.moe_config.is_act_and_mul:
......@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
w13_scale = swap_w13_to_w31(w13_scale)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales. Note that we do not register
# as nn.Parameters since they are not needed for weight-reloading.
# and registration of alpha scales.
if is_trtllm and not block_quant:
assert w13_input_scale is not None
assert w2_input_scale is not None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused
......
......@@ -172,7 +172,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
# Further check if the ModularKernel implementation uses the DeepGemmExperts
return isinstance(
module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts)
module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts)
)
......
......@@ -88,9 +88,14 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
Without autotuning, FlashInfer will rely on heuristics, which may
be significantly slower.
"""
from vllm.utils.flashinfer import autotune
import vllm.utils.flashinfer as fi_utils
with torch.inference_mode(), fi_utils.autotune():
# Certain FlashInfer kernels (e.g. nvfp4 routed moe) are
# incompatible with autotuning. This state is used to skip
# those kernels during the autotuning process.
fi_utils._is_fi_autotuning = True
with torch.inference_mode(), autotune():
# We skip EPLB here since we don't want to record dummy metrics
# When autotuning with number of tokens m, flashinfer will autotune
# operations for all number of tokens up to m.
......@@ -100,3 +105,5 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
skip_eplb=True,
is_profile=True,
)
fi_utils._is_fi_autotuning = False
......@@ -140,6 +140,7 @@ autotune = _lazy_import_wrapper(
"autotune",
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
_is_fi_autotuning: bool = False
@functools.cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment