Commit d76fc11e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
......@@ -576,9 +576,6 @@ class FusedMoE(CustomOp):
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
)
self.moe_config_use_flashinfer_cutlass_kernels = (
self.moe_config.use_flashinfer_cutlass_kernels
)
if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, (
"Mori needs to be used with aiter fused_moe for now."
......@@ -671,6 +668,11 @@ class FusedMoE(CustomOp):
# This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
# NOTE(rob): WIP refactor. For quant methods that own the MK
# we create the MK during process_weights_after_loading.
if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic:
return None
self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
......@@ -753,14 +755,6 @@ class FusedMoE(CustomOp):
def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return (
self.moe_quant_config is not None
and self.moe_quant_config.quant_dtype == "nvfp4"
and self.moe_config_use_flashinfer_cutlass_kernels
)
@property
def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False)
......@@ -771,7 +765,7 @@ class FusedMoE(CustomOp):
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
or self.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property
......@@ -1571,7 +1565,7 @@ class FusedMoE(CustomOp):
assert self.quant_method is not None
return (
isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.fused_experts.output_is_reduced()
and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
......@@ -1835,7 +1829,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init()
has_separate_shared_experts = (
not isinstance(self.quant_method, FusedMoEModularMethod)
not self.quant_method.mk_owns_shared_expert
and self.shared_experts is not None
)
......@@ -1859,8 +1853,10 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, has_separate_shared_experts
)
do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
self.quant_method, FusedMoEModularMethod
# NOTE(rob): once we finish migrating all the quant methods to use
# MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine = (
self.dp_size > 1 and not self.quant_method.supports_internal_mk
)
ctx = get_forward_context()
......@@ -1888,7 +1884,7 @@ class FusedMoE(CustomOp):
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch(
dispatch_res = get_ep_group().dispatch_router_logits(
hidden_states_to_dispatch,
router_logits,
self.is_sequence_parallel,
......
......@@ -180,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> PrepareResultType:
"""
Perform any quantization (and/or) dispatching needed for this kernel.
......@@ -192,6 +193,9 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a tuple of:
- quantized + dispatched a.
......@@ -220,6 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> tuple[Callable, ReceiverType] | ReceiverType:
"""
Perform any quantization (and/or) dispatching needed for this kernel
......@@ -235,6 +240,9 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
in cases where the compute kernel expects unquantized inputs
Returns a callback or a hook callback pair that when invoked waits for
results from other workers and has the same return signature as
......@@ -407,10 +415,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@staticmethod
def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
@property
def expects_unquantized_inputs(self) -> bool:
"""
Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will
......@@ -1069,6 +1075,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
else:
# Overlap shared expert compute with all2all dispatch.
......@@ -1081,6 +1088,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
# TODO(lucas): refactor this in the alternative schedules followup
......
......@@ -3,70 +3,6 @@
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: torch.Tensor | None,
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: torch.Tensor | None,
block_m: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
)
inv_perm: torch.Tensor | None = None
num_tokens = top_k_num * tokens_in_chunk
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: torch.Tensor | None,
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.size()
K = curr_hidden.size(-1)
if inv_perm is not None:
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
if not apply_router_weight_on_input:
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def moe_permute(
hidden_states: torch.Tensor,
......
......@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
"""
Returns a tuple of:
......@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
"""
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now."
)
......
......@@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -17,9 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
......@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -330,9 +331,16 @@ def select_fp8_moe_backend(
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No FP8 MoE backend supports the deployment configuration."
)
# TODO(rob): per discussion with TPU team, we need a way to register
# MoE backends by OOT plugins, rather than having an explicit list
# of AVAILBLE_BACKENDS. Enabling returning `Fp8MoeBackend.NONE` is
# a temporary measure until these register APIs are complete.
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
"No FP8 MoE backend supports the deployment configuration."
)
return Fp8MoeBackend.NONE, None
def convert_to_fp8_moe_kernel_format(
......@@ -457,68 +465,52 @@ def make_fp8_moe_quant_config(
)
def make_fp8_moe_kernel_for_mkm(
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> mk.FusedMoEPermuteExpertsUnpermute:
fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens_per_rank,
quant_config=moe_quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"Fp8 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
quant_config=moe_quant_config,
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts=None,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
)
......
......@@ -7,6 +7,9 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -14,9 +17,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
......@@ -391,67 +391,51 @@ def make_nvfp4_moe_quant_config(
)
def make_nvfp4_moe_kernel_for_mkm(
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> mk.FusedMoEPermuteExpertsUnpermute:
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens_per_rank,
quant_config=moe_quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
quant_config=moe_quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts=None,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
)
......
......@@ -106,7 +106,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
......@@ -274,6 +281,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
......@@ -283,6 +291,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map,
apply_router_weight_on_input,
quant_config,
defer_input_quant=defer_input_quant,
)
hook()
return receiver()
......
......@@ -4,18 +4,25 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(self, defer_input_quant: bool = False) -> None:
class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__()
self.defer_input_quant = defer_input_quant
self.is_sequence_parallel = is_sequence_parallel
self._num_dispatchers = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
......@@ -27,6 +34,113 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return self._num_dispatchers
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quantization to the MoE kernel.
use_nvfp4 = quant_config.use_nvfp4_w4a4
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
res = get_ep_group().dispatch(
a1q,
topk_weights,
topk_ids,
is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales,
)
if skip_gather_scales:
a1q, topk_weights, topk_ids = res
else:
a1q, topk_weights, topk_ids, scales = res
assert scales is not None and len(scales) == 1
a1q_scale = scales[0]
if quant_config.quant_dtype == "nvfp4":
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
out = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
output.copy_(
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return 1
......@@ -42,6 +156,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
......@@ -54,12 +169,17 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
if self.defer_input_quant:
if defer_input_quant:
return a1, None, None, None, None
input_sf = (
quant_config.a1_gscale
if quant_config.use_nvfp4_w4a4
else quant_config.a1_scale
)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_scale,
input_sf,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
......
......@@ -287,17 +287,14 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def expects_unquantized_inputs(
fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# AITER fused MoE kernels handle input quantization internally.
return True
@staticmethod
def _supports_current_device() -> bool:
return rocm_aiter_ops.is_fused_moe_enabled()
......@@ -329,7 +326,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
return not moe_parallel_config.use_fi_all2allv_kernels
def supports_expert_map(self):
return True
......
......@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE):
use_overlapped
and not (
(self.enable_eplb and backend != "allgather_reducescatter")
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
or self.moe_parallel_config.use_fi_all2allv_kernels
)
and self._shared_experts is not None
)
......
......@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
......@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1,
)
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native(
self,
hidden_states: torch.Tensor,
......@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp):
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
is_blackwell=self.is_blackwell,
)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
......
......@@ -286,6 +286,7 @@ def selective_state_update(
out=None,
num_accepted_tokens=None,
cu_seqlens=None,
is_blackwell=False,
):
"""
Argument:
......@@ -391,17 +392,26 @@ def selective_state_update(
if dst_state_batch_indices is not None
else (0, 0)
)
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
(32, 4)
if dstate <= 16
else (
(16, 4)
if dstate <= 32
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
)
)
# We don't want autotune since it will overwrite the state.
# We instead tune by hand based on dstate.
# Default
BLOCK_SIZE_M, num_warps = 4, 8
if dstate <= 16:
BLOCK_SIZE_M, num_warps = 32, 4
elif dstate <= 32:
BLOCK_SIZE_M, num_warps = 16, 4
elif dstate <= 64:
BLOCK_SIZE_M, num_warps = 8, 4
else:
# dstate > 64
if is_blackwell:
# Optimized for B200 with dstate>64
BLOCK_SIZE_M, num_warps = 32, 8
elif dstate <= 128:
BLOCK_SIZE_M, num_warps = 4, 4
tie_hdim = (
A.stride(-1) == 0
and A.stride(-2) == 0
......
......@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
......@@ -53,7 +52,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
......@@ -67,7 +65,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
......@@ -244,7 +241,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
......@@ -321,7 +317,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False
)
......@@ -336,10 +332,12 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
self.kernel = make_nvfp4_moe_kernel(
self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
def apply(
......@@ -349,8 +347,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
return self.kernel(
assert self.moe_mk is not None
return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -381,19 +379,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
activation_key=None if use_a16 else kNvfp4Dynamic,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights(
self,
layer: torch.nn.Module,
......@@ -507,7 +496,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def process_weights_after_loading(self, layer: FusedMoE) -> None:
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
......@@ -573,48 +562,33 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
if self.moe_quant_config:
assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel(
self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_nvfp4_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(
......@@ -685,8 +659,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
)
else:
assert self.kernel is not None
return self.kernel(
assert self.moe_mk is not None
return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -760,15 +734,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
allow_vllm_cutlass=True,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights(
self,
layer: torch.nn.Module,
......@@ -928,7 +893,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def process_weights_after_loading(self, layer: FusedMoE) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
......@@ -947,7 +912,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
pass
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def process_weights_after_loading(self, layer: FusedMoE) -> None:
# Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight
w2 = layer.w2_weight
......@@ -1009,49 +975,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
if self.moe_quant_config:
assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(
......@@ -1142,8 +1093,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.kernel is not None
return self.kernel(
assert self.moe_mk is not None
return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
......
......@@ -16,6 +16,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale,
)
from vllm.model_executor.parameter import (
......@@ -159,9 +162,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
# divisible by 32). fbgemm has its own layout requirements.
if self.backend in ("cutlass", "flashinfer-cutlass"):
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight_packed.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
layer.weights_padding_cols = 0
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)
layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale),
......@@ -187,7 +201,8 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out
output_dtype = x.dtype
output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]
output_size = layer.output_size_per_partition
output_shape = [*x.shape[:-1], output_size]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
......@@ -197,6 +212,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
backend=self.backend,
)
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = (
x_fp4,
layer.weight_packed,
......@@ -221,6 +240,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None:
out = out + bias
return out.view(*output_shape)
......@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
......@@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
......@@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_vllm_cutlass=False,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights(
self,
layer: Module,
......@@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel(
self,
layer: Module,
layer: FusedMoE,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
......@@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
if self.moe_quant_config:
assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: Module) -> None:
......@@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(
......@@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
assert self.moe_mk is not None
assert not self.is_monolithic
return self.kernel(
return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
......
......@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
......@@ -35,7 +34,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
......@@ -54,13 +52,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
......@@ -84,6 +80,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTokenSym,
kNvfp4Dynamic,
kNvfp4Static,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
......@@ -736,47 +735,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
activation_key=kFp8StaticTensorSym,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
# TRT LLM not supported with all2all yet.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def create_weights(
......@@ -860,7 +835,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel(
self,
layer: torch.nn.Module,
layer: FusedMoE,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
......@@ -890,11 +865,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
......@@ -995,8 +972,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
f"but got {layer.activation}"
)
assert self.kernel is not None
return self.kernel(
assert self.moe_mk is not None
return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1280,9 +1257,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False)
else:
# Swizzle block scales and pad the packed NVFP4 weights for kernel
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight = Parameter(weight, requires_grad=False)
def apply(
self,
......@@ -1304,7 +1288,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
)
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
......@@ -1319,6 +1302,12 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert layer.weight_scale.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
output_size = layer.output_size_per_partition
output_shape = [x.shape[0], output_size]
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = (
x_fp4,
layer.weight,
......@@ -1327,6 +1316,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.alpha,
output_dtype,
)
if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
......@@ -1334,6 +1324,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None:
out = out + bias
return out.view(*output_shape)
......@@ -1360,50 +1353,27 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation_key=kNvfp4Dynamic,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_nvfp4_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def uses_weight_scale_2_pattern(self) -> bool:
......@@ -1528,7 +1498,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def process_weights_after_loading(self, layer: FusedMoE) -> None:
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
......@@ -1580,15 +1550,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
if self.moe_quant_config:
assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel(
self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
@property
......@@ -1689,8 +1658,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
)
else:
assert self.kernel is not None
return self.kernel(
assert self.moe_mk is not None
return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
......
......@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
None, # routing_bias
x_quant,
x_scale,
layer.w13_weight, # uint8 (e2m1 x 2)
layer.w13_weight_scale, # uint8 (e4m3 x 2)
layer.w13_bias, # fp32 per expert per channel
layer.gemm1_alpha, # fp32 per expert
layer.gemm1_beta, # fp32 per expert
layer.gemm1_clamp_limit, # fp32 per expert
layer.w2_weight, # uint8 (e2m1 x 2)
layer.w2_weight_scale, # ue8m0
layer.w2_bias, # fp32 per expert per channel
None, # output1_scale_scalar
None, # output1_scale_gate_scalar
None, # output2_scale_scalar
layer.global_num_experts,
layer.top_k,
None, # n_group
None, # topk_group
self.intermediate_size, # padded to multiple of 256
layer.ep_rank * layer.local_num_experts, # local_expert_offset
self.num_experts, # local num experts
None, # routed_scaling_factor
1 if layer.renormalize else 0, # routing_method_type, renormalize
True, # do finalize
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2)
gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2)
gemm1_bias=layer.w13_bias, # fp32 per expert per channel
gemm1_alpha=layer.gemm1_alpha, # fp32 per expert
gemm1_beta=layer.gemm1_beta, # fp32 per expert
gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert
gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2)
gemm2_weights_scale=layer.w2_weight_scale, # ue8m0
gemm2_bias=layer.w2_bias, # fp32 per expert per channel
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=layer.global_num_experts,
top_k=layer.top_k,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size, # padded to multiple of 256
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=self.num_experts,
routed_scaling_factor=None,
routing_method_type=1 if layer.renormalize else 0,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
return trtllm_gen_output
......
......@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
......@@ -42,7 +39,6 @@ __all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
#
......@@ -163,17 +159,6 @@ def reorder_w1w3_to_w3w1(
)
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
return create_flashinfer_prepare_finalize(
use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
)
def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant,
# args,
......
......@@ -4,15 +4,8 @@ from enum import Enum
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
......@@ -163,18 +156,6 @@ def make_fp8_moe_alpha_scales_for_fi(
return g1_alphas, g2_alphas
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
# Propagate block-scale flag so prepare/finalize can skip act quantization
# and inform the kernel to consume per-block weight scales.
return create_flashinfer_prepare_finalize(
use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
)
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
......
......@@ -868,3 +868,70 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens
t |= ((nib - 8) & 0xF) << shift
return t
def round_up(x: int, m: int) -> int:
"""Round up x to the nearest multiple of m."""
return (x + m - 1) // m * m
def pad_nvfp4_weight_for_cutlass(
weight: torch.Tensor,
alignment: int = 32,
) -> tuple[torch.Tensor, int]:
"""
Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy
the alignment constraints required by CUTLASS / FlashInfer FP4 kernels.
CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible
by 32 for aligned memory access and efficient tensor core operations.
"""
weight_current_rows = weight.shape[0]
# Pad N dimension (rows) if not aligned
if weight_current_rows % alignment != 0:
total_rows = round_up(weight_current_rows, alignment)
pad_rows = total_rows - weight_current_rows
weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous()
# Check K dimension alignment
# 2 FP4 items are packed per byte in the input dimension
weight_current_col_bytes = weight.shape[1]
weight_current_col_elements = weight_current_col_bytes * 2
weights_padding_bytes = 0
if weight_current_col_elements % alignment != 0:
total_cols = round_up(weight_current_col_elements, alignment)
pad_cols = total_cols - weight_current_col_elements
# Convert from FP4 element count to bytes (2 FP4 values per byte)
# pad_cols is always even since alignment=32 and current elements are even
pad_bytes = pad_cols // 2
weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous()
weights_padding_bytes = pad_bytes
return weight, weights_padding_bytes
def pad_nvfp4_activation_for_cutlass(
x_fp4: torch.Tensor,
weights_padding_bytes: int,
) -> torch.Tensor:
"""
Pad packed FP4 activations to match the K-dimension padding applied to weights.
The padding is in bytes (tensor dimension), not FP4 elements.
"""
if weights_padding_bytes > 0:
return torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes)).contiguous()
return x_fp4
def slice_nvfp4_output(
out: torch.Tensor,
output_size: int,
) -> torch.Tensor:
"""
Slice the output tensor to remove padding in N dimension if weight was padded.
"""
if out.shape[-1] != output_size:
return out[..., :output_size].contiguous()
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment