Commit d76fc11e authored by zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
...@@ -576,9 +576,6 @@ class FusedMoE(CustomOp): ...@@ -576,9 +576,6 @@ class FusedMoE(CustomOp):
device=vllm_config.device_config.device, device=vllm_config.device_config.device,
routing_method=self.routing_method_type, routing_method=self.routing_method_type,
) )
self.moe_config_use_flashinfer_cutlass_kernels = (
self.moe_config.use_flashinfer_cutlass_kernels
)
if self.use_mori_kernels: if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, ( assert self.rocm_aiter_fmoe_enabled, (
"Mori needs to be used with aiter fused_moe for now." "Mori needs to be used with aiter fused_moe for now."
...@@ -671,6 +668,11 @@ class FusedMoE(CustomOp): ...@@ -671,6 +668,11 @@ class FusedMoE(CustomOp):
# This is called after all weight loading and post-processing, so it # This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method. # should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None: def maybe_init_modular_kernel(self) -> None:
# NOTE(rob): WIP refactor. For quant methods that own the MK
# we create the MK during process_weights_after_loading.
if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic:
return None
self.ensure_moe_quant_config_init() self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with # routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend. # DeepEP all2all backend.
...@@ -753,14 +755,6 @@ class FusedMoE(CustomOp): ...@@ -753,14 +755,6 @@ class FusedMoE(CustomOp):
def use_mori_kernels(self): def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels return self.moe_parallel_config.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return (
self.moe_quant_config is not None
and self.moe_quant_config.quant_dtype == "nvfp4"
and self.moe_config_use_flashinfer_cutlass_kernels
)
@property @property
def use_marlin_kernels(self): def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False) return getattr(self.quant_method, "use_marlin", False)
...@@ -771,7 +765,7 @@ class FusedMoE(CustomOp): ...@@ -771,7 +765,7 @@ class FusedMoE(CustomOp):
self.moe_parallel_config.use_pplx_kernels self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels or self.moe_parallel_config.use_mori_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) or self.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property @property
...@@ -1571,7 +1565,7 @@ class FusedMoE(CustomOp): ...@@ -1571,7 +1565,7 @@ class FusedMoE(CustomOp):
assert self.quant_method is not None assert self.quant_method is not None
return ( return (
isinstance(self.quant_method, FusedMoEModularMethod) isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.fused_experts.output_is_reduced() and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
) )
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
...@@ -1835,7 +1829,7 @@ class FusedMoE(CustomOp): ...@@ -1835,7 +1829,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init() self.ensure_dp_chunking_init()
has_separate_shared_experts = ( has_separate_shared_experts = (
not isinstance(self.quant_method, FusedMoEModularMethod) not self.quant_method.mk_owns_shared_expert
and self.shared_experts is not None and self.shared_experts is not None
) )
...@@ -1859,8 +1853,10 @@ class FusedMoE(CustomOp): ...@@ -1859,8 +1853,10 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, has_separate_shared_experts hidden_states, router_logits, has_separate_shared_experts
) )
do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( # NOTE(rob): once we finish migrating all the quant methods to use
self.quant_method, FusedMoEModularMethod # MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine = (
self.dp_size > 1 and not self.quant_method.supports_internal_mk
) )
ctx = get_forward_context() ctx = get_forward_context()
...@@ -1888,7 +1884,7 @@ class FusedMoE(CustomOp): ...@@ -1888,7 +1884,7 @@ class FusedMoE(CustomOp):
else: else:
hidden_states_to_dispatch = hidden_states hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch( dispatch_res = get_ep_group().dispatch_router_logits(
hidden_states_to_dispatch, hidden_states_to_dispatch,
router_logits, router_logits,
self.is_sequence_parallel, self.is_sequence_parallel,
......
...@@ -180,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -180,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> PrepareResultType: ) -> PrepareResultType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel. Perform any quantization (and/or) dispatching needed for this kernel.
...@@ -192,6 +193,9 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -192,6 +193,9 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts. - quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs.
Returns a tuple of: Returns a tuple of:
- quantized + dispatched a. - quantized + dispatched a.
...@@ -220,6 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -220,6 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> tuple[Callable, ReceiverType] | ReceiverType: ) -> tuple[Callable, ReceiverType] | ReceiverType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel Perform any quantization (and/or) dispatching needed for this kernel
...@@ -235,6 +240,9 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -235,6 +240,9 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard. space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs.
Returns a callback or a hook callback pair that when invoked waits for Returns a callback or a hook callback pair that when invoked waits for
results from other workers and has the same return signature as results from other workers and has the same return signature as
...@@ -407,10 +415,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -407,10 +415,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
@staticmethod @property
def expects_unquantized_inputs( def expects_unquantized_inputs(self) -> bool:
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
""" """
Whether or not the PrepareFinalize should defer input quantization Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will in the prepare step. If True, then the Experts kernel will
...@@ -1069,6 +1075,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1069,6 +1075,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
) )
else: else:
# Overlap shared expert compute with all2all dispatch. # Overlap shared expert compute with all2all dispatch.
...@@ -1081,6 +1088,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1081,6 +1088,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
) )
# TODO(lucas): refactor this in the alternative schedules followup # TODO(lucas): refactor this in the alternative schedules followup
......
...@@ -3,70 +3,6 @@ ...@@ -3,70 +3,6 @@
import torch import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: torch.Tensor | None,
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: torch.Tensor | None,
block_m: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
)
inv_perm: torch.Tensor | None = None
num_tokens = top_k_num * tokens_in_chunk
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: torch.Tensor | None,
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.size()
K = curr_hidden.size(-1)
if inv_perm is not None:
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
if not apply_router_weight_on_input:
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def moe_permute( def moe_permute(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
""" """
Returns a tuple of: Returns a tuple of:
...@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- Optional dispatched expert topk IDs - Optional dispatched expert topk IDs
- Optional dispatched expert topk weight - Optional dispatched expert topk weight
""" """
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now." "mori does not support apply_router_weight_on_input=True now."
) )
......
...@@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -17,9 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -17,9 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm, is_supported_config_trtllm,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
...@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( ...@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
) )
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -330,9 +331,16 @@ def select_fp8_moe_backend( ...@@ -330,9 +331,16 @@ def select_fp8_moe_backend(
else: else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local") logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError( # TODO(rob): per discussion with TPU team, we need a way to register
"No FP8 MoE backend supports the deployment configuration." # MoE backends by OOT plugins, rather than having an explicit list
# of AVAILABLE_BACKENDS. Enabling returning `Fp8MoeBackend.NONE` is
# a temporary measure until these register APIs are complete.
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
"No FP8 MoE backend supports the deployment configuration."
)
return Fp8MoeBackend.NONE, None
def convert_to_fp8_moe_kernel_format( def convert_to_fp8_moe_kernel_format(
...@@ -457,68 +465,52 @@ def make_fp8_moe_quant_config( ...@@ -457,68 +465,52 @@ def make_fp8_moe_quant_config(
) )
def make_fp8_moe_kernel_for_mkm( def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
prepare_finalize: mk.FusedMoEPrepareAndFinalize, fp8_backend: Fp8MoeBackend,
) -> mk.FusedMoEPermuteExpertsUnpermute: routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None assert max_num_tokens is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"Fp8 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
) )
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explicit in # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=None, shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
) )
......
...@@ -7,6 +7,9 @@ import torch ...@@ -7,6 +7,9 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -14,9 +17,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -14,9 +17,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm, is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
...@@ -391,67 +391,51 @@ def make_nvfp4_moe_quant_config( ...@@ -391,67 +391,51 @@ def make_nvfp4_moe_quant_config(
) )
def make_nvfp4_moe_kernel_for_mkm( def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
prepare_finalize: mk.FusedMoEPrepareAndFinalize, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPermuteExpertsUnpermute: shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None assert max_num_tokens is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=moe_quant_config,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explicit in # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=None, shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
) )
......
...@@ -106,7 +106,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -106,7 +106,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]: ) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
...@@ -274,6 +281,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -274,6 +281,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
...@@ -283,6 +291,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -283,6 +291,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config, quant_config,
defer_input_quant=defer_input_quant,
) )
hook() hook()
return receiver() return receiver()
......
...@@ -4,18 +4,25 @@ ...@@ -4,18 +4,25 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
) )
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
def __init__(self, defer_input_quant: bool = False) -> None: def __init__(
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__() super().__init__()
self.defer_input_quant = defer_input_quant self.is_sequence_parallel = is_sequence_parallel
self._num_dispatchers = num_dispatchers
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
...@@ -27,6 +34,113 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -27,6 +34,113 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def topk_indices_dtype(self) -> torch.dtype | None: def topk_indices_dtype(self) -> torch.dtype | None:
return None return None
def num_dispatchers(self) -> int:
return self._num_dispatchers
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quantization to the MoE kernel.
use_nvfp4 = quant_config.use_nvfp4_w4a4
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
res = get_ep_group().dispatch(
a1q,
topk_weights,
topk_ids,
is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales,
)
if skip_gather_scales:
a1q, topk_weights, topk_ids = res
else:
a1q, topk_weights, topk_ids, scales = res
assert scales is not None and len(scales) == 1
a1q_scale = scales[0]
if quant_config.quant_dtype == "nvfp4":
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
out = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
output.copy_(
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return 1 return 1
...@@ -42,6 +156,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -42,6 +156,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -54,12 +169,17 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -54,12 +169,17 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
# Defer input quant to moe kernel for backends (e.g. AITER, FI) # Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts. # which use a single kernel call for quant + experts.
if self.defer_input_quant: if defer_input_quant:
return a1, None, None, None, None return a1, None, None, None, None
input_sf = (
quant_config.a1_gscale
if quant_config.use_nvfp4_w4a4
else quant_config.a1_scale
)
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
quant_config.a1_scale, input_sf,
quant_config.quant_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.per_act_token_quant,
quant_config.block_shape, quant_config.block_shape,
......
...@@ -287,17 +287,14 @@ def rocm_aiter_fused_experts( ...@@ -287,17 +287,14 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
@staticmethod
def expects_unquantized_inputs(
fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# AITER fused MoE kernels handle input quantization internally.
return True
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
return rocm_aiter_ops.is_fused_moe_enabled() return rocm_aiter_ops.is_fused_moe_enabled()
...@@ -329,7 +326,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -329,7 +326,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True return not moe_parallel_config.use_fi_all2allv_kernels
def supports_expert_map(self): def supports_expert_map(self):
return True return True
......
...@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE): ...@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE):
use_overlapped use_overlapped
and not ( and not (
(self.enable_eplb and backend != "allgather_reducescatter") (self.enable_eplb and backend != "allgather_reducescatter")
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) or self.moe_parallel_config.use_fi_all2allv_kernels
) )
and self._shared_experts is not None and self._shared_experts is not None
) )
......
...@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader, sharded_weight_loader,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
...@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1, dim=-1,
) )
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native( def forward_native(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp):
state_batch_indices=state_indices_tensor_d_input, state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output, dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
is_blackwell=self.is_blackwell,
) )
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
......
...@@ -286,6 +286,7 @@ def selective_state_update( ...@@ -286,6 +286,7 @@ def selective_state_update(
out=None, out=None,
num_accepted_tokens=None, num_accepted_tokens=None,
cu_seqlens=None, cu_seqlens=None,
is_blackwell=False,
): ):
""" """
Argument: Argument:
...@@ -391,17 +392,26 @@ def selective_state_update( ...@@ -391,17 +392,26 @@ def selective_state_update(
if dst_state_batch_indices is not None if dst_state_batch_indices is not None
else (0, 0) else (0, 0)
) )
# We don't want autotune since it will overwrite the state # We don't want autotune since it will overwrite the state.
# We instead tune by hand. # We instead tune by hand based on dstate.
BLOCK_SIZE_M, num_warps = (
(32, 4) # Default
if dstate <= 16 BLOCK_SIZE_M, num_warps = 4, 8
else (
(16, 4) if dstate <= 16:
if dstate <= 32 BLOCK_SIZE_M, num_warps = 32, 4
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) elif dstate <= 32:
) BLOCK_SIZE_M, num_warps = 16, 4
) elif dstate <= 64:
BLOCK_SIZE_M, num_warps = 8, 4
else:
# dstate > 64
if is_blackwell:
# Optimized for B200 with dstate>64
BLOCK_SIZE_M, num_warps = 32, 8
elif dstate <= 128:
BLOCK_SIZE_M, num_warps = 4, 4
tie_hdim = ( tie_hdim = (
A.stride(-1) == 0 A.stride(-1) == 0
and A.stride(-2) == 0 and A.stride(-2) == 0
......
...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -53,7 +52,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -53,7 +52,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -67,7 +65,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( ...@@ -67,7 +65,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
...@@ -244,7 +241,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -244,7 +241,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.group_size = 32 self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights( def create_weights(
self, self,
...@@ -321,7 +317,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -321,7 +317,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False layer.w13_weight_packed.data, requires_grad=False
) )
...@@ -336,10 +332,12 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -336,10 +332,12 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None: if self.moe_quant_config is not None:
self.kernel = make_nvfp4_moe_kernel( self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def apply( def apply(
...@@ -349,8 +347,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -349,8 +347,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -381,19 +379,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -381,19 +379,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
activation_key=None if use_a16 else kNvfp4Dynamic, activation_key=None if use_a16 else kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -507,7 +496,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -507,7 +496,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
set_weight_attrs(w2_input_scale, extra_weight_attrs) set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -573,48 +562,33 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -573,48 +562,33 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel( self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: raise ValueError(
return None f"{self.__class__.__name__} uses the new modular kernel initialization "
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: "logic. This function should not be called."
# For no-EP case, don't use the MKM framework. )
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_nvfp4_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -685,8 +659,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -685,8 +659,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -760,15 +734,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -760,15 +734,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
allow_vllm_cutlass=True, allow_vllm_cutlass=True,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -928,7 +893,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -928,7 +893,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
E=layer.w13_weight.shape[0] E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1] N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1] N2=layer.w2_weight.shape[1]
...@@ -947,7 +912,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -947,7 +912,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
pass pass
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def process_weights_after_loading(self, layer: FusedMoE) -> None:
# Allow for accessing weights and scales in standard way. # Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight w13 = layer.w13_weight
w2 = layer.w2_weight w2 = layer.w2_weight
...@@ -1009,49 +975,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1009,49 +975,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: raise ValueError(
return None f"{self.__class__.__name__} uses the new modular kernel initialization "
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: "logic. This function should not be called."
# For no-EP case, don't use the MKM framework. )
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_fp8_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -1142,8 +1093,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1142,8 +1093,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
......
...@@ -16,6 +16,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ...@@ -16,6 +16,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported, cutlass_fp4_supported,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
...@@ -159,9 +162,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -159,9 +162,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
if self.backend == "fbgemm": if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False # Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
) # divisible by 32). fbgemm has its own layout requirements.
if self.backend in ("cutlass", "flashinfer-cutlass"):
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight_packed.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
layer.weights_padding_cols = 0
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)
layer.alpha = Parameter( layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale), 1 / (layer.input_global_scale * layer.weight_global_scale),
...@@ -187,7 +201,8 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -187,7 +201,8 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out return out
output_dtype = x.dtype output_dtype = x.dtype
output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] output_size = layer.output_size_per_partition
output_shape = [*x.shape[:-1], output_size]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
...@@ -197,6 +212,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -197,6 +212,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
backend=self.backend, backend=self.backend,
) )
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = ( mm_args = (
x_fp4, x_fp4,
layer.weight_packed, layer.weight_packed,
...@@ -221,6 +240,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -221,6 +240,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
assert self.backend == "cutlass" assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args) out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)
...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ...@@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_vllm_cutlass=False, allow_vllm_cutlass=False,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: Module, layer: Module,
...@@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: Module, layer: FusedMoE,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
...@@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: raise ValueError(
return None f"{self.__class__.__name__} uses the new modular kernel initialization "
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: "logic. This function should not be called."
# For no-EP case, don't use the MKM framework. )
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_fp8_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.moe_mk is not None
assert not self.is_monolithic assert not self.is_monolithic
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
......
...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -35,7 +34,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -35,7 +34,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -54,13 +52,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ...@@ -54,13 +52,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
) )
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe, flashinfer_trtllm_fp4_routed_moe,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -84,6 +80,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -84,6 +80,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTokenSym, kFp8StaticTokenSym,
kNvfp4Dynamic, kNvfp4Dynamic,
kNvfp4Static, kNvfp4Static,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...@@ -736,47 +735,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -736,47 +735,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
activation_key=kFp8StaticTensorSym, activation_key=kFp8StaticTensorSym,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
# TRT LLM not supported with all2all yet. raise ValueError(
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: f"{self.__class__.__name__} uses the new modular kernel initialization "
return None "logic. This function should not be called."
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: )
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_fp8_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def create_weights( def create_weights(
...@@ -860,7 +835,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -860,7 +835,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -890,11 +865,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -890,11 +865,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -995,8 +972,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -995,8 +972,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
f"but got {layer.activation}" f"but got {layer.activation}"
) )
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1280,9 +1257,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1280,9 +1257,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False)
else: else:
# Swizzle block scales and pad the packed NVFP4 weights for kernel
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight = Parameter(weight, requires_grad=False)
def apply( def apply(
self, self,
...@@ -1304,7 +1288,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1304,7 +1288,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
) )
output_dtype = x.dtype output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
...@@ -1319,6 +1302,12 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1319,6 +1302,12 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert layer.weight_scale.dtype == torch.float8_e4m3fn assert layer.weight_scale.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32 assert layer.alpha.dtype == torch.float32
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
output_size = layer.output_size_per_partition
output_shape = [x.shape[0], output_size]
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = ( mm_args = (
x_fp4, x_fp4,
layer.weight, layer.weight,
...@@ -1327,6 +1316,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1327,6 +1316,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.alpha, layer.alpha,
output_dtype, output_dtype,
) )
if self.backend.startswith("flashinfer-"): if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :] backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
...@@ -1334,6 +1324,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1334,6 +1324,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert self.backend == "cutlass" assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args) out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)
...@@ -1360,50 +1353,27 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1360,50 +1353,27 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation_key=kNvfp4Dynamic, activation_key=kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: raise ValueError(
return None f"{self.__class__.__name__} uses the new modular kernel initialization "
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: "logic. This function should not be called."
# For no-EP case, don't use the MKM framework. )
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_nvfp4_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def uses_weight_scale_2_pattern(self) -> bool: def uses_weight_scale_2_pattern(self) -> bool:
...@@ -1528,7 +1498,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1528,7 +1498,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) )
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -1580,15 +1550,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1580,15 +1550,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel( self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
@property @property
...@@ -1689,8 +1658,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1689,8 +1658,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
......
...@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe( trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16), routing_logits=router_logits.to(torch.bfloat16),
None, # routing_bias routing_bias=None,
x_quant, hidden_states=x_quant,
x_scale, hidden_states_scale=x_scale,
layer.w13_weight, # uint8 (e2m1 x 2) gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2)
layer.w13_weight_scale, # uint8 (e4m3 x 2) gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2)
layer.w13_bias, # fp32 per expert per channel gemm1_bias=layer.w13_bias, # fp32 per expert per channel
layer.gemm1_alpha, # fp32 per expert gemm1_alpha=layer.gemm1_alpha, # fp32 per expert
layer.gemm1_beta, # fp32 per expert gemm1_beta=layer.gemm1_beta, # fp32 per expert
layer.gemm1_clamp_limit, # fp32 per expert gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert
layer.w2_weight, # uint8 (e2m1 x 2) gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2)
layer.w2_weight_scale, # ue8m0 gemm2_weights_scale=layer.w2_weight_scale, # ue8m0
layer.w2_bias, # fp32 per expert per channel gemm2_bias=layer.w2_bias, # fp32 per expert per channel
None, # output1_scale_scalar output1_scale_scalar=None,
None, # output1_scale_gate_scalar output1_scale_gate_scalar=None,
None, # output2_scale_scalar output2_scale_scalar=None,
layer.global_num_experts, num_experts=layer.global_num_experts,
layer.top_k, top_k=layer.top_k,
None, # n_group n_group=None,
None, # topk_group topk_group=None,
self.intermediate_size, # padded to multiple of 256 intermediate_size=self.intermediate_size, # padded to multiple of 256
layer.ep_rank * layer.local_num_experts, # local_expert_offset local_expert_offset=layer.ep_rank * layer.local_num_experts,
self.num_experts, # local num experts local_num_experts=self.num_experts,
None, # routed_scaling_factor routed_scaling_factor=None,
1 if layer.renormalize else 0, # routing_method_type, renormalize routing_method_type=1 if layer.renormalize else 0,
True, # do finalize do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
return trtllm_gen_output return trtllm_gen_output
......
...@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kNvfp4Dynamic, kNvfp4Dynamic,
...@@ -42,7 +39,6 @@ __all__ = [ ...@@ -42,7 +39,6 @@ __all__ = [
"is_flashinfer_fp4_cutlass_moe_available", "is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available", "is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1", "reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
] ]
# #
...@@ -163,17 +159,6 @@ def reorder_w1w3_to_w3w1( ...@@ -163,17 +159,6 @@ def reorder_w1w3_to_w3w1(
) )
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
    moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
    """Build the FlashInfer CUTLASS prepare/finalize object for NVFP4 MoE.

    Reads the parallel configuration from ``moe`` and forwards it to
    ``create_flashinfer_prepare_finalize`` with ``use_nvfp4=True``.

    Args:
        moe: Fused-MoE configuration; only ``moe_parallel_config`` is read.

    Returns:
        Whatever ``create_flashinfer_prepare_finalize`` returns for the
        derived flags.
    """
    parallel_cfg = moe.moe_parallel_config
    return create_flashinfer_prepare_finalize(
        # Data parallelism is considered active with more than one DP rank.
        use_dp=parallel_cfg.dp_size > 1,
        use_nvfp4=True,
        # All-to-all-v is enabled only for this specific backend name.
        enable_alltoallv=parallel_cfg.all2all_backend == "flashinfer_all2allv",
    )
def prepare_static_weights_for_trtllm_fp4_moe( def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant, # args_dequant,
# args, # args,
......
...@@ -4,15 +4,8 @@ from enum import Enum ...@@ -4,15 +4,8 @@ from enum import Enum
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
...@@ -163,18 +156,6 @@ def make_fp8_moe_alpha_scales_for_fi( ...@@ -163,18 +156,6 @@ def make_fp8_moe_alpha_scales_for_fi(
return g1_alphas, g2_alphas return g1_alphas, g2_alphas
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
    moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
) -> mk.FusedMoEPrepareAndFinalize:
    """Build the FlashInfer CUTLASS prepare/finalize object for FP8 MoE.

    Args:
        moe: Fused-MoE configuration, or ``None``. When ``None``, data
            parallelism is treated as disabled.
        use_deepseek_fp8_block_scale: Forwarded unchanged to
            ``create_flashinfer_prepare_finalize`` so prepare/finalize can
            skip activation quantization and the kernel can consume
            per-block weight scales.

    Returns:
        Whatever ``create_flashinfer_prepare_finalize`` returns for the
        derived flags.
    """
    if moe is None:
        data_parallel = False
    else:
        # Data parallelism is considered active with more than one DP rank.
        data_parallel = moe.moe_parallel_config.dp_size > 1

    return create_flashinfer_prepare_finalize(
        data_parallel,
        use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
    )
def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = { backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS, "throughput": FlashinferMoeBackend.CUTLASS,
......
...@@ -868,3 +868,70 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens ...@@ -868,3 +868,70 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens
t |= ((nib - 8) & 0xF) << shift t |= ((nib - 8) & 0xF) << shift
return t return t
def round_up(x: int, m: int) -> int:
    """Return ``x`` rounded up to a multiple of ``m`` (intended for ``m > 0``)."""
    # Biasing by m - 1 before the floor division turns it into a ceiling
    # division; only the quotient is needed.
    multiples, _ = divmod(x + m - 1, m)
    return multiples * m
def pad_nvfp4_weight_for_cutlass(
    weight: torch.Tensor,
    alignment: int = 32,
) -> tuple[torch.Tensor, int]:
    """
    Zero-pad packed NVFP4 weights so both N (rows) and K (columns) are aligned.

    CUTLASS / FlashInfer FP4 kernels require both the K and N matrix
    dimensions to be divisible by 32 for aligned memory access and efficient
    tensor core operations.

    The weight is packed with two FP4 values per byte along the column
    dimension, so ``weight.shape[1]`` counts bytes and K is
    ``2 * weight.shape[1]`` FP4 elements.

    Args:
        weight: 2-D packed weight tensor (rows = N, columns = K // 2 bytes).
        alignment: Required multiple, in FP4 elements, for both N and K.
            Must be positive.

    Returns:
        A ``(padded_weight, weights_padding_bytes)`` tuple.
        ``padded_weight`` is ``weight`` itself when no padding is needed.
        ``weights_padding_bytes`` is the number of bytes appended to every
        row (0 when K was already aligned); the same amount must be applied
        to the packed activations, e.g. via
        ``pad_nvfp4_activation_for_cutlass``.

    Raises:
        ValueError: If ``alignment`` is not positive or ``weight`` is not 2-D.
    """
    if alignment <= 0:
        raise ValueError(f"alignment must be positive, got {alignment}")
    if weight.dim() != 2:
        raise ValueError(
            f"expected a 2-D packed weight, got shape {tuple(weight.shape)}"
        )

    num_rows, num_col_bytes = weight.shape

    # ``-n % a`` is the distance from n up to the next multiple of a
    # (0 when n is already aligned).
    pad_rows = -num_rows % alignment

    # K is counted in FP4 elements but padding is added in whole bytes
    # (2 elements each), so the padded K must be a multiple of ``alignment``
    # that is also even. For an odd alignment that means 2 * alignment;
    # otherwise the byte conversion below would truncate and leave K unaligned.
    col_alignment = alignment if alignment % 2 == 0 else 2 * alignment
    pad_bytes = (-(2 * num_col_bytes) % col_alignment) // 2

    if pad_rows or pad_bytes:
        # A single pad call handles both dimensions with one allocation.
        # The pad tuple is ordered (last-dim left, last-dim right,
        # second-to-last-dim left, second-to-last-dim right).
        weight = torch.nn.functional.pad(
            weight, (0, pad_bytes, 0, pad_rows)
        ).contiguous()

    return weight, pad_bytes
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_bytes: int,
) -> torch.Tensor:
    """
    Extend packed FP4 activations along the last dimension to mirror the
    K-dimension padding that was applied to the weights.

    ``weights_padding_bytes`` is measured in bytes (tensor elements of the
    packed tensor), not in FP4 values. When it is zero or negative the input
    tensor is returned as-is.
    """
    if weights_padding_bytes <= 0:
        return x_fp4

    # Append zeros on the right of the last dimension only.
    padded = torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes))
    return padded.contiguous()
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """
    Drop N-dimension padding from a GEMM output.

    If the last dimension of ``out`` already equals ``output_size`` the
    tensor is returned untouched; otherwise the last dimension is cut to its
    first ``output_size`` entries and the result is made contiguous.
    """
    if out.shape[-1] == output_size:
        return out

    trimmed = out[..., :output_size]
    return trimmed.contiguous()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment