Commit c721b814 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1

parent d53fe7e5
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
...@@ -573,6 +572,9 @@ class FusedMoE(CustomOp): ...@@ -573,6 +572,9 @@ class FusedMoE(CustomOp):
device=vllm_config.device_config.device, device=vllm_config.device_config.device,
routing_method=self.routing_method_type, routing_method=self.routing_method_type,
) )
self.moe_config_use_flashinfer_cutlass_kernels = (
self.moe_config.use_flashinfer_cutlass_kernels
)
if self.use_mori_kernels: if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, ( assert self.rocm_aiter_fmoe_enabled, (
"Mori needs to be used with aiter fused_moe for now." "Mori needs to be used with aiter fused_moe for now."
...@@ -645,10 +647,6 @@ class FusedMoE(CustomOp): ...@@ -645,10 +647,6 @@ class FusedMoE(CustomOp):
# This is called after all weight loading and post-processing, so it # This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method. # should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None: def maybe_init_modular_kernel(self) -> None:
# NOTE(rob): WIP refactor. For quant methods that own the MK
# we create the MK during process_weights_after_loading.
if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic:
return None
self.ensure_moe_quant_config_init() self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with # routing_tables only needed for round-robin expert placement with
...@@ -732,6 +730,14 @@ class FusedMoE(CustomOp): ...@@ -732,6 +730,14 @@ class FusedMoE(CustomOp):
def use_mori_kernels(self): def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels return self.moe_parallel_config.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return (
self.moe_quant_config is not None
and self.moe_quant_config.quant_dtype == "nvfp4"
and self.moe_config_use_flashinfer_cutlass_kernels
)
@property @property
def use_marlin_kernels(self): def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False) return getattr(self.quant_method, "use_marlin", False)
...@@ -742,7 +748,7 @@ class FusedMoE(CustomOp): ...@@ -742,7 +748,7 @@ class FusedMoE(CustomOp):
self.moe_parallel_config.use_pplx_kernels self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels or self.moe_parallel_config.use_mori_kernels
or self.moe_parallel_config.use_fi_all2allv_kernels or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
) and envs.VLLM_ENABLE_MOE_DP_CHUNK ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property @property
...@@ -1528,7 +1534,7 @@ class FusedMoE(CustomOp): ...@@ -1528,7 +1534,7 @@ class FusedMoE(CustomOp):
assert self.quant_method is not None assert self.quant_method is not None
return ( return (
isinstance(self.quant_method, FusedMoEModularMethod) isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr] and self.quant_method.fused_experts.output_is_reduced()
) )
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
...@@ -1761,7 +1767,7 @@ class FusedMoE(CustomOp): ...@@ -1761,7 +1767,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init() self.ensure_dp_chunking_init()
has_separate_shared_experts = ( has_separate_shared_experts = (
not self.quant_method.mk_owns_shared_expert not isinstance(self.quant_method, FusedMoEModularMethod)
and self.shared_experts is not None and self.shared_experts is not None
) )
...@@ -1785,10 +1791,8 @@ class FusedMoE(CustomOp): ...@@ -1785,10 +1791,8 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, has_separate_shared_experts hidden_states, router_logits, has_separate_shared_experts
) )
# NOTE(rob): once we finish migrating all the quant methods to use do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
# MKs, we can remove the naive dispatch/combine path from here. self.quant_method, FusedMoEModularMethod
do_naive_dispatch_combine = (
self.dp_size > 1 and not self.quant_method.supports_internal_mk
) )
ctx = get_forward_context() ctx = get_forward_context()
...@@ -1816,7 +1820,7 @@ class FusedMoE(CustomOp): ...@@ -1816,7 +1820,7 @@ class FusedMoE(CustomOp):
else: else:
hidden_states_to_dispatch = hidden_states hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch_router_logits( dispatch_res = get_ep_group().dispatch(
hidden_states_to_dispatch, hidden_states_to_dispatch,
router_logits, router_logits,
self.is_sequence_parallel, self.is_sequence_parallel,
......
...@@ -180,7 +180,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -180,7 +180,6 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> PrepareResultType: ) -> PrepareResultType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel. Perform any quantization (and/or) dispatching needed for this kernel.
...@@ -193,9 +192,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -193,9 +192,6 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts. - quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a tuple of: Returns a tuple of:
- quantized + dispatched a. - quantized + dispatched a.
...@@ -224,7 +220,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -224,7 +220,6 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> tuple[Callable, ReceiverType] | ReceiverType: ) -> tuple[Callable, ReceiverType] | ReceiverType:
""" """
Perform any quantization (and/or) dispatching needed for this kernel Perform any quantization (and/or) dispatching needed for this kernel
...@@ -240,9 +235,6 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -240,9 +235,6 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard. space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a callback or a hook callback pair that when invoked waits for Returns a callback or a hook callback pair that when invoked waits for
results from other workers and has the same return signature as results from other workers and has the same return signature as
...@@ -415,8 +407,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -415,8 +407,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
@property @staticmethod
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
""" """
Whether or not the PrepareFinalize should defer input quantization Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will in the prepare step. If True, then the Experts kernel will
......
...@@ -3,6 +3,70 @@ ...@@ -3,6 +3,70 @@
import torch import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: torch.Tensor | None,
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: torch.Tensor | None,
block_m: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
)
inv_perm: torch.Tensor | None = None
num_tokens = top_k_num * tokens_in_chunk
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: torch.Tensor | None,
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.size()
K = curr_hidden.size(-1)
if inv_perm is not None:
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
if not apply_router_weight_on_input:
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def moe_permute( def moe_permute(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -162,4 +226,4 @@ def moe_unpermute( ...@@ -162,4 +226,4 @@ def moe_unpermute(
def moe_permute_unpermute_supported(): def moe_permute_unpermute_supported():
return torch.ops._moe_C.moe_permute_unpermute_supported() return torch.ops._moe_C.moe_permute_unpermute_supported()
\ No newline at end of file
...@@ -58,7 +58,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -58,7 +58,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
""" """
Returns a tuple of: Returns a tuple of:
...@@ -70,11 +69,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -70,11 +69,6 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- Optional dispatched expert topk IDs - Optional dispatched expert topk IDs
- Optional dispatched expert topk weight - Optional dispatched expert topk weight
""" """
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now." "mori does not support apply_router_weight_on_input=True now."
) )
......
...@@ -8,9 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -8,9 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -20,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -20,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_fp8, is_supported_config_trtllm_fp8,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
...@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( ...@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
) )
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -331,14 +330,9 @@ def select_fp8_moe_backend( ...@@ -331,14 +330,9 @@ def select_fp8_moe_backend(
else: else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local") logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
# TODO(rob): per discussion with TPU team, we need a way to register raise NotImplementedError(
# MoE backends by OOT plugins, rather than having an explicit list "No FP8 MoE backend supports the deployment configuration."
# of AVAILBLE_BACKENDS. Enabling returning `Fp8MoeBackend.NONE` is )
# a temporary measure until these register APIs are complete.
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
"No FP8 MoE backend supports the deployment configuration."
)
return Fp8MoeBackend.NONE, None return Fp8MoeBackend.NONE, None
...@@ -465,55 +459,71 @@ def make_fp8_moe_quant_config( ...@@ -465,55 +459,71 @@ def make_fp8_moe_quant_config(
) )
def make_fp8_moe_kernel( def make_fp8_moe_kernel_for_mkm(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
fp8_backend: Fp8MoeBackend, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPermuteExpertsUnpermute:
shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank() max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None assert max_num_tokens_per_rank is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=quant_config,
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=quant_config,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"Fp8 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in # if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=( shared_experts=None,
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
) )
# TODO(rob): update inplace logic to be part of the kernel. # TODO(rob): update inplace logic to be part of the kernel.
inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
return kernel, inplace return kernel, inplace
\ No newline at end of file
...@@ -7,9 +7,6 @@ import torch ...@@ -7,9 +7,6 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -17,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -17,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm, is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
...@@ -391,53 +391,69 @@ def make_nvfp4_moe_quant_config( ...@@ -391,53 +391,69 @@ def make_nvfp4_moe_quant_config(
) )
def make_nvfp4_moe_kernel( def make_nvfp4_moe_kernel_for_mkm(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None = None, ) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank() max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None assert max_num_tokens_per_rank is not None
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=quant_config,
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else: else:
experts = experts_cls( experts = experts_cls(
moe_config=moe_config, moe_config=moe_config,
quant_config=moe_quant_config, quant_config=quant_config,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert # NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in # if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class. # the new MoE runner class.
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=( shared_experts=None,
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
) )
# TODO(rob): update inplace logic to be part of the kernel. # TODO(rob): update inplace logic to be part of the kernel.
return kernel return kernel
\ No newline at end of file
...@@ -106,13 +106,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -106,13 +106,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]: ) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
...@@ -281,7 +275,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -281,7 +275,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook, receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
...@@ -291,7 +284,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -291,7 +284,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config, quant_config,
defer_input_quant=defer_input_quant,
) )
hook() hook()
return receiver() return receiver()
......
...@@ -4,133 +4,19 @@ ...@@ -4,133 +4,19 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
) )
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__( def __init__(self, defer_input_quant: bool = False) -> None:
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__() super().__init__()
self.is_sequence_parallel = is_sequence_parallel self.defer_input_quant = defer_input_quant
self._num_dispatchers = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return self._num_dispatchers
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quantization to the MoE kernel.
use_nvfp4 = quant_config.use_nvfp4_w4a4
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
res = get_ep_group().dispatch(
a1q,
topk_weights,
topk_ids,
is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales,
)
if skip_gather_scales:
a1q, topk_weights, topk_ids = res
else:
a1q, topk_weights, topk_ids, scales = res
assert scales is not None and len(scales) == 1
a1q_scale = scales[0]
if quant_config.quant_dtype == "nvfp4":
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
out = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
output.copy_(
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
...@@ -156,7 +42,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -156,7 +42,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -169,17 +54,12 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -169,17 +54,12 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
# Defer input quant to moe kernel for backends (e.g. AITER, FI) # Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts. # which use a single kernel call for quant + experts.
if defer_input_quant: if self.defer_input_quant:
return a1, None, None, None, None return a1, None, None, None, None
input_sf = (
quant_config.a1_gscale
if quant_config.use_nvfp4_w4a4
else quant_config.a1_scale
)
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
input_sf, quant_config.a1_scale,
quant_config.quant_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.per_act_token_quant,
quant_config.block_shape, quant_config.block_shape,
...@@ -204,4 +84,4 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -204,4 +84,4 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
\ No newline at end of file
...@@ -287,14 +287,18 @@ def rocm_aiter_fused_experts( ...@@ -287,14 +287,18 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
@staticmethod
def expects_unquantized_inputs(
fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# AITER fused MoE kernels handle input quantization internally.
return True
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
return rocm_aiter_ops.is_fused_moe_enabled() return rocm_aiter_ops.is_fused_moe_enabled()
...@@ -326,7 +330,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -326,7 +330,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not moe_parallel_config.use_fi_all2allv_kernels return True
def supports_expert_map(self): def supports_expert_map(self):
return True return True
......
...@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE): ...@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE):
use_overlapped use_overlapped
and not ( and not (
(self.enable_eplb and backend != "allgather_reducescatter") (self.enable_eplb and backend != "allgather_reducescatter")
or self.moe_parallel_config.use_fi_all2allv_kernels or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
) )
and self._shared_experts is not None and self._shared_experts is not None
) )
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader, sharded_weight_loader,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
...@@ -503,9 +502,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -503,9 +502,6 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1, dim=-1,
) )
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native( def forward_native(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -286,7 +286,6 @@ def selective_state_update( ...@@ -286,7 +286,6 @@ def selective_state_update(
out=None, out=None,
num_accepted_tokens=None, num_accepted_tokens=None,
cu_seqlens=None, cu_seqlens=None,
is_blackwell=False,
): ):
""" """
Argument: Argument:
...@@ -392,26 +391,17 @@ def selective_state_update( ...@@ -392,26 +391,17 @@ def selective_state_update(
if dst_state_batch_indices is not None if dst_state_batch_indices is not None
else (0, 0) else (0, 0)
) )
# We don't want autotune since it will overwrite the state. # We don't want autotune since it will overwrite the state
# We instead tune by hand based on dstate. # We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
# Default (32, 4)
BLOCK_SIZE_M, num_warps = 4, 8 if dstate <= 16
else (
if dstate <= 16: (16, 4)
BLOCK_SIZE_M, num_warps = 32, 4 if dstate <= 32
elif dstate <= 32: else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
BLOCK_SIZE_M, num_warps = 16, 4 )
elif dstate <= 64: )
BLOCK_SIZE_M, num_warps = 8, 4
else:
# dstate > 64
if is_blackwell:
# Optimized for B200 with dstate>64
BLOCK_SIZE_M, num_warps = 32, 8
elif dstate <= 128:
BLOCK_SIZE_M, num_warps = 4, 4
tie_hdim = ( tie_hdim = (
A.stride(-1) == 0 A.stride(-1) == 0
and A.stride(-2) == 0 and A.stride(-2) == 0
...@@ -593,4 +583,4 @@ def selective_scan_fn( ...@@ -593,4 +583,4 @@ def selective_scan_fn(
if z is None: if z is None:
return delta # output written inplace to delta return delta # output written inplace to delta
else: else:
return z # output written inplace to z return z # output written inplace to z
\ No newline at end of file
...@@ -166,10 +166,6 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -166,10 +166,6 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions,
self.rotary_emb)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
......
...@@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -51,6 +52,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -51,6 +52,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -64,6 +66,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( ...@@ -64,6 +66,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
...@@ -239,6 +242,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -239,6 +242,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.group_size = 32 self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights( def create_weights(
self, self,
...@@ -315,7 +319,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -315,7 +319,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
) )
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False layer.w13_weight_packed.data, requires_grad=False
) )
...@@ -330,12 +334,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -330,12 +334,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None: if self.moe_quant_config is not None:
self.moe_mk = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def apply( def apply(
...@@ -345,8 +347,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -345,8 +347,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -377,10 +379,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -377,10 +379,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
activation_key=None if use_a16 else kNvfp4Dynamic, activation_key=None if use_a16 else kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -494,7 +505,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -494,7 +505,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
set_weight_attrs(w2_input_scale, extra_weight_attrs) set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -560,33 +571,48 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -560,33 +571,48 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
f"{self.__class__.__name__} uses the new modular kernel initialization " return None
"logic. This function should not be called." elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
) # For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_nvfp4_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -657,8 +683,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -657,8 +683,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -732,6 +758,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -732,6 +758,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
allow_vllm_cutlass=True, allow_vllm_cutlass=True,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -891,7 +926,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -891,7 +926,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Allow for accessing weights and scales in standard way. # Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight w13 = layer.w13_weight
w2 = layer.w2_weight w2 = layer.w2_weight
...@@ -953,34 +988,49 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -953,34 +988,49 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
f"{self.__class__.__name__} uses the new modular kernel initialization " return None
"logic. This function should not be called." elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
) # For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -1061,8 +1111,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1061,8 +1111,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -2452,4 +2502,4 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2452,4 +2502,4 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
@property @property
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return False return False
\ No newline at end of file
...@@ -16,9 +16,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ...@@ -16,9 +16,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported, cutlass_fp4_supported,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
...@@ -162,20 +159,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -162,20 +159,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
if self.backend == "fbgemm": if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight_packed = Parameter(
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N layer.weight_packed.data, requires_grad=False
# divisible by 32). fbgemm has its own layout requirements. )
if self.backend in ("cutlass", "flashinfer-cutlass"):
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight_packed.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
layer.weights_padding_cols = 0
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)
layer.alpha = Parameter( layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale), 1 / (layer.input_global_scale * layer.weight_global_scale),
...@@ -201,8 +187,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -201,8 +187,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out return out
output_dtype = x.dtype output_dtype = x.dtype
output_size = layer.output_size_per_partition output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]
output_shape = [*x.shape[:-1], output_size]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
...@@ -212,10 +197,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -212,10 +197,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
backend=self.backend, backend=self.backend,
) )
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = ( mm_args = (
x_fp4, x_fp4,
layer.weight_packed, layer.weight_packed,
...@@ -240,9 +221,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -240,9 +221,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
assert self.backend == "cutlass" assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args) out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)
\ No newline at end of file
...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -51,6 +52,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ...@@ -51,6 +52,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -676,6 +678,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -676,6 +678,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_vllm_cutlass=False, allow_vllm_cutlass=False,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: Module, layer: Module,
...@@ -801,7 +812,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -801,7 +812,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: FusedMoE, layer: Module,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -833,15 +844,16 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -833,15 +844,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
...@@ -896,19 +908,33 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -896,19 +908,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
f"{self.__class__.__name__} uses the new modular kernel initialization " return None
"logic. This function should not be called." elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
) # For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -1002,9 +1028,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1002,9 +1028,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None assert self.kernel is not None
assert not self.is_monolithic assert not self.is_monolithic
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1164,4 +1190,4 @@ class Fp8KVCacheMethod(BaseKVCacheMethod): ...@@ -1164,4 +1190,4 @@ class Fp8KVCacheMethod(BaseKVCacheMethod):
""" """
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config):
super().__init__(quant_config) super().__init__(quant_config)
\ No newline at end of file
...@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -52,11 +54,13 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ...@@ -52,11 +54,13 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
) )
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe, flashinfer_trtllm_fp4_routed_moe,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -80,9 +84,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -80,9 +84,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTokenSym, kFp8StaticTokenSym,
kNvfp4Dynamic, kNvfp4Dynamic,
kNvfp4Static, kNvfp4Static,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...@@ -735,23 +736,47 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -735,23 +736,47 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
activation_key=kFp8StaticTensorSym, activation_key=kFp8StaticTensorSym,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( # TRT LLM not supported with all2all yet.
f"{self.__class__.__name__} uses the new modular kernel initialization " if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
"logic. This function should not be called." return None
) elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def create_weights( def create_weights(
...@@ -835,7 +860,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -835,7 +860,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: FusedMoE, layer: torch.nn.Module,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -865,13 +890,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -865,13 +890,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -972,8 +995,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -972,8 +995,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
f"but got {layer.activation}" f"but got {layer.activation}"
) )
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1257,16 +1280,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1257,16 +1280,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False)
else: else:
# Swizzle block scales and pad the packed NVFP4 weights for kernel
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight = Parameter(weight, requires_grad=False)
def apply( def apply(
self, self,
...@@ -1288,6 +1304,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1288,6 +1304,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
) )
output_dtype = x.dtype output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
...@@ -1302,12 +1319,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1302,12 +1319,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert layer.weight_scale.dtype == torch.float8_e4m3fn assert layer.weight_scale.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32 assert layer.alpha.dtype == torch.float32
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
output_size = layer.output_size_per_partition
output_shape = [x.shape[0], output_size]
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = ( mm_args = (
x_fp4, x_fp4,
layer.weight, layer.weight,
...@@ -1316,7 +1327,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1316,7 +1327,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.alpha, layer.alpha,
output_dtype, output_dtype,
) )
if self.backend.startswith("flashinfer-"): if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :] backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
...@@ -1324,9 +1334,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1324,9 +1334,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert self.backend == "cutlass" assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args) out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)
...@@ -1353,27 +1360,50 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1353,27 +1360,50 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation_key=kNvfp4Dynamic, activation_key=kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
f"{self.__class__.__name__} uses the new modular kernel initialization " return None
"logic. This function should not be called." elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
) # For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_nvfp4_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def uses_weight_scale_2_pattern(self) -> bool: def uses_weight_scale_2_pattern(self) -> bool:
...@@ -1498,7 +1528,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1498,7 +1528,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) )
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -1550,14 +1580,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1550,14 +1580,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
@property @property
...@@ -1658,8 +1689,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1658,8 +1689,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1675,4 +1706,4 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1675,4 +1706,4 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
\ No newline at end of file
...@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe( trtllm_gen_output = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16), router_logits.to(torch.bfloat16),
routing_bias=None, None, # routing_bias
hidden_states=x_quant, x_quant,
hidden_states_scale=x_scale, x_scale,
gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2) layer.w13_weight, # uint8 (e2m1 x 2)
gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2) layer.w13_weight_scale, # uint8 (e4m3 x 2)
gemm1_bias=layer.w13_bias, # fp32 per expert per channel layer.w13_bias, # fp32 per expert per channel
gemm1_alpha=layer.gemm1_alpha, # fp32 per expert layer.gemm1_alpha, # fp32 per expert
gemm1_beta=layer.gemm1_beta, # fp32 per expert layer.gemm1_beta, # fp32 per expert
gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert layer.gemm1_clamp_limit, # fp32 per expert
gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2) layer.w2_weight, # uint8 (e2m1 x 2)
gemm2_weights_scale=layer.w2_weight_scale, # ue8m0 layer.w2_weight_scale, # ue8m0
gemm2_bias=layer.w2_bias, # fp32 per expert per channel layer.w2_bias, # fp32 per expert per channel
output1_scale_scalar=None, None, # output1_scale_scalar
output1_scale_gate_scalar=None, None, # output1_scale_gate_scalar
output2_scale_scalar=None, None, # output2_scale_scalar
num_experts=layer.global_num_experts, layer.global_num_experts,
top_k=layer.top_k, layer.top_k,
n_group=None, None, # n_group
topk_group=None, None, # topk_group
intermediate_size=self.intermediate_size, # padded to multiple of 256 self.intermediate_size, # padded to multiple of 256
local_expert_offset=layer.ep_rank * layer.local_num_experts, layer.ep_rank * layer.local_num_experts, # local_expert_offset
local_num_experts=self.num_experts, self.num_experts, # local num experts
routed_scaling_factor=None, None, # routed_scaling_factor
routing_method_type=1 if layer.renormalize else 0, 1 if layer.renormalize else 0, # routing_method_type, renormalize
do_finalize=True, True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
return trtllm_gen_output return trtllm_gen_output
...@@ -1170,4 +1170,4 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): ...@@ -1170,4 +1170,4 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
activation="swiglu_oai", activation="swiglu_oai",
) )
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous() hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
return hidden_states return hidden_states
\ No newline at end of file
...@@ -14,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -14,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kNvfp4Dynamic, kNvfp4Dynamic,
...@@ -32,6 +35,7 @@ logger = init_logger(__name__) ...@@ -32,6 +35,7 @@ logger = init_logger(__name__)
__all__ = [ __all__ = [
"reorder_w1w3_to_w3w1", "reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
] ]
# #
...@@ -132,6 +136,17 @@ def reorder_w1w3_to_w3w1( ...@@ -132,6 +136,17 @@ def reorder_w1w3_to_w3w1(
) )
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
return create_flashinfer_prepare_finalize(
use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
)
def prepare_static_weights_for_trtllm_fp4_moe( def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant, # args_dequant,
# args, # args,
...@@ -526,4 +541,4 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ...@@ -526,4 +541,4 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
w2_scale = swizzle_blockscale(w2_scale) w2_scale = swizzle_blockscale(w2_scale)
return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale
\ No newline at end of file
...@@ -4,8 +4,15 @@ from enum import Enum ...@@ -4,8 +4,15 @@ from enum import Enum
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
...@@ -156,6 +163,18 @@ def make_fp8_moe_alpha_scales_for_fi( ...@@ -156,6 +163,18 @@ def make_fp8_moe_alpha_scales_for_fi(
return g1_alphas, g2_alphas return g1_alphas, g2_alphas
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
# Propagate block-scale flag so prepare/finalize can skip act quantization
# and inform the kernel to consume per-block weight scales.
return create_flashinfer_prepare_finalize(
use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
)
def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = { backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS, "throughput": FlashinferMoeBackend.CUTLASS,
...@@ -293,4 +312,4 @@ def prepare_fp8_moe_layer_for_fi( ...@@ -293,4 +312,4 @@ def prepare_fp8_moe_layer_for_fi(
w2_input_scale=w2_input_scale, w2_input_scale=w2_input_scale,
) )
return w13, w2, w13_scale return w13, w2, w13_scale
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment