Commit 0da93439 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.1rc0' into v0.18.1rc0-ori

parents 25f2f756 298e5108
......@@ -101,6 +101,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
return None
@property
def skip_forward_padding(self) -> bool:
    """Whether to skip the padding in the forward before applying the moe method.

    This base implementation always returns False.
    """
    return False
@property
def supports_eplb(self) -> bool:
    """Whether this MoE method supports EPLB.

    This base implementation always returns False.
    """
    return False
......
......@@ -11,8 +11,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
......@@ -20,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kMxfp4Static,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
......@@ -142,6 +145,33 @@ def legacy_routing_from_bitmatrix(
return routing_data, gather_idx, scatter_idx
def legacy_routing_from_sparsematrix(
    sparse_logits: "SparseMatrix",
    n_expts_tot: int,
    n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
    """
    Build the legacy routing triple from a SparseMatrix.

    The row/column sorted indices and the per-column sums are read from
    ``sparse_logits.mask_metadata``; the gate values are the flattened
    ``sparse_logits.vals`` gathered with the column-sorted index.

    Returns:
        ``(routing_data, gather_idx, scatter_idx)``.
    """
    row_order = sparse_logits.mask_metadata.row_sorted_indx
    col_order = sparse_logits.mask_metadata.col_sorted_indx

    ragged_meta = make_ragged_tensor_metadata(
        sparse_logits.mask_metadata.col_sum,
        row_order.shape[0],
    )
    flat_vals = sparse_logits.vals.flatten()
    gate_values = flat_vals[col_order]

    routing = RoutingData(
        gate_values,
        ragged_meta.block_sizes,
        n_expts_tot,
        n_expts_act,
        ragged_meta,
    )
    # Gather and scatter use the same two index tensors in swapped order.
    gather = GatherIndx(col_order, row_order)
    scatter = ScatterIndx(row_order, col_order)
    return routing, gather, scatter
def legacy_routing(
logits: torch.Tensor,
n_expts_act: int,
......@@ -158,10 +188,8 @@ def legacy_routing(
if sm_first:
logits = torch.softmax(logits, dim=-1)
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
return legacy_routing_from_bitmatrix(
sparse_logits.mask,
sparse_logits.vals,
sparse_logits.indx,
return legacy_routing_from_sparsematrix(
sparse_logits,
logits.shape[-1],
n_expts_act,
)
......@@ -512,43 +540,43 @@ def make_routing_data(
class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
p = current_platform
if not p.is_cuda_alike():
return False
cap = p.get_device_capability()
if cap is None:
return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
SUPPORTED_W_A = [
(kMxfp4Static, None),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
raise NotImplementedError
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
return True
def supports_expert_map(self) -> bool:
return True
......@@ -605,6 +633,10 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
class OAITritonExperts(BaseOAITritonExperts):
"""OAI Triton-based fused MoE expert implementation."""
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Return True only for the SWIGLUOAI activation."""
    return activation == MoEActivation.SWIGLUOAI
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format of this expert implementation (Standard)."""
    return mk.FusedMoEActivationFormat.Standard
......@@ -689,6 +721,15 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum.
"""
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Return True if ``activation`` is SILU, GELU, SWIGLUOAI or SWIGLUSTEP."""
    supported_activations = [
        MoEActivation.SILU,
        MoEActivation.GELU,
        MoEActivation.SWIGLUOAI,
        MoEActivation.SWIGLUSTEP,
    ]
    is_supported = activation in supported_activations
    return is_supported
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format of this expert implementation (Standard)."""
    return mk.FusedMoEActivationFormat.Standard
......@@ -814,3 +855,118 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
)
self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
    """Monolithic MXFP4 expert that delegates to triton_kernel_moe_forward().

    The top-k count and the renormalize flag are captured from the MoE config
    at construction time and reused on every ``apply`` call.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)
        renormalizing_methods = (
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        )
        self.topk = moe_config.experts_per_token
        self.renormalize = moe_config.routing_method in renormalizing_methods

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation format used by this expert (Standard)."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Return True on CUDA-alike devices with capability in [9.0, 11.0)."""
        platform = current_platform
        if not platform.is_cuda_alike():
            return False
        capability = platform.get_device_capability()
        if capability is None:
            return False
        # (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
        # and ROCm gfx942/gfx950 (which map to 9.4/9.5).
        version = (capability.major, capability.minor)
        return (9, 0) <= version < (11, 0)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Return False unconditionally."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only static MXFP4 weights with no activation quant key."""
        requested_scheme = (weight_key, activation_key)
        return requested_scheme == (kMxfp4Static, None)

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Return True only for the SWIGLUOAI activation."""
        return activation == MoEActivation.SWIGLUOAI

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Reject configs using all2all kernels, EPLB, or dp_size > 1."""
        return not (
            moe_parallel_config.use_all2all_kernels
            or moe_parallel_config.enable_eplb
            or moe_parallel_config.dp_size > 1
        )

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only the Renormalize and RenormalizeNaive routing methods.

        ``weight_key`` and ``activation_key`` are not consulted.
        """
        accepted_methods = [
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        ]
        return routing_method in accepted_methods

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """Return True unconditionally; neither argument is consulted."""
        return True

    def supports_expert_map(self) -> bool:
        """Return True unconditionally."""
        return True

    @property
    def expects_unquantized_inputs(self) -> bool:
        """Return True unconditionally."""
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """Run the MoE forward pass through triton_kernel_moe_forward().

        ``activation``, ``a1q_scale`` and the grouped-topk / bias parameters
        are accepted but not forwarded to the kernel wrapper.
        """
        forward_kwargs = dict(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            gating_output=router_logits,
            topk=self.topk,
            renormalize=self.renormalize,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            quant_config=self.quant_config,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
        return triton_kernel_moe_forward(**forward_kwargs)
......@@ -52,7 +52,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
......@@ -218,7 +217,6 @@ def maybe_roundup_hidden_size(
moe_parallel_config: FusedMoEParallelConfig,
is_lora_enabled: bool,
model_type: str | None,
is_mxfp4_quant: bool,
) -> int:
"""
Given layer hidden size and MoE configurations, round up hidden_size
......@@ -232,7 +230,6 @@ def maybe_roundup_hidden_size(
is used in the case of mxfp4 quantization in selecting the
MxFP4Backend.
model_type: for checking if gpt-oss
is_mxfp4_quant: whether the layer is quantized with mxfp4
Return:
Rounded up hidden_size if rounding up is required based on the configs.
......@@ -246,28 +243,6 @@ def maybe_roundup_hidden_size(
hidden_size, act_dtype, moe_parallel_config
)
# we are padding globally so EP buffer allocation works
if model_type == "gpt_oss" and is_mxfp4_quant:
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
get_mxfp4_backend,
)
current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
if (
current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
):
hidden_size = round_up(hidden_size, 128)
elif (
current_platform.is_rocm()
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.MARLIN
):
hidden_size = round_up(hidden_size, 256)
return hidden_size
......@@ -504,6 +479,8 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = MoEActivation.from_str(activation)
# TODO(bnell): we should not have to create a router if the kernel is
# monolithic.
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
......@@ -538,9 +515,6 @@ class FusedMoE(CustomOp):
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=self.model_type,
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
)
self.hidden_size = hidden_size
......
......@@ -70,16 +70,13 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
"""
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now."
)
scale = None
if self.use_fp8_dispatch:
# When defer_input_quant is True, the expert kernel handles
# quantization internally, so skip FP8 dispatch quantization.
if self.use_fp8_dispatch and not defer_input_quant:
from aiter import QuantType, get_hip_quant
if quant_config.is_block_quantized:
......
......@@ -444,7 +444,7 @@ def convert_to_fp8_moe_kernel_format(
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_fi(
layer=layer,
w13=w13,
w2=w2,
......@@ -512,6 +512,21 @@ def make_fp8_moe_quant_config(
g1_alphas=(w1_scale * a1_scale).squeeze(),
g2_alphas=(w2_scale * a2_scale).squeeze(),
)
# MXFP8 uses "mxfp8" quant_dtype so the prepare step dispatches to
# _mxfp8_e4m3_quantize rather than standard FP8 block quantization.
# Non-swizzled layout is required since the TRTLLM kernel expects
# scales in (num_tokens, hidden_dim // 32) format.
if block_shape == [1, 32]:
return FusedMoEQuantConfig.make(
"mxfp8",
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
is_nvfp4_scale_swizzled=False,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Union
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
mxfp4_mxfp8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_swizzle_mxfp4,
get_padding_alignment,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kMxfp4Static,
kMxfp8Dynamic,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
# Optional dependency: the import is attempted only when has_triton_kernels()
# reports the package as available. A failed import is logged and otherwise
# ignored, so PrecisionConfig may be left undefined at module level; the type
# annotations in this module refer to it by its string name.
if has_triton_kernels():
    try:
        from triton_kernels.matmul_ogs import PrecisionConfig
    except (ImportError, AttributeError) as e:
        logger.error(
            "Failed to import Triton kernels. Please make sure your triton "
            "version is compatible. Error: %s",
            e,
        )
class Mxfp4MoeBackend(Enum):
    """Identifiers for the MXFP4 MoE kernel backends handled by this module.

    Each member's value is the string used for it in log and error messages.
    """

    NONE = "None"
    # FlashInfer TRTLLM backends
    FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8"
    FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16"
    # FlashInfer CUTLASS backends
    FLASHINFER_CUTLASS_MXFP4_MXFP8 = "FLASHINFER_CUTLASS_MXFP4_MXFP8"
    FLASHINFER_CUTLASS_MXFP4_BF16 = "FLASHINFER_CUTLASS_MXFP4_BF16"
    # Marlin
    BATCHED_MARLIN = "BATCHED_MARLIN"
    MARLIN = "MARLIN"
    # ROCm AITER (CK)
    CK = "CK"
    # Triton
    TRITON = "TRITON"
    TRITON_UNFUSED = "TRITON_UNFUSED"
    # XPU
    XPU = "XPU"
# Backends that share the same TRTLLM weight format
TRTLLM_BACKENDS = (
    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
)
# Backends whose weights take the triton_kernels swizzle / PrecisionConfig
# path in convert_to_mxfp4_moe_kernel_format.
TRITON_BACKENDS = (
    Mxfp4MoeBackend.TRITON,
    Mxfp4MoeBackend.TRITON_UNFUSED,
)
def backend_to_kernel_cls(
    backend: Mxfp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """Return the candidate expert classes for ``backend``, most preferred first.

    Kernel modules are imported lazily inside the matching branch, so only the
    selected backend's module is loaded.

    Raises:
        NotImplementedError: for the XPU backend, which has no kernel class here.
        ValueError: for any backend not handled below.
    """
    trtllm_backends = (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
    )
    cutlass_backends = (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )

    if backend in trtllm_backends:
        from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import (
            TrtLlmMxfp4ExpertsModular,
            TrtLlmMxfp4ExpertsMonolithic,
        )

        # NOTE: prefer Monolithic > Modular, so return Monolithic first.
        return [TrtLlmMxfp4ExpertsMonolithic, TrtLlmMxfp4ExpertsModular]

    if backend in cutlass_backends:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        return [FlashInferExperts]

    if backend == Mxfp4MoeBackend.TRITON:
        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
            OAITritonExperts,
            OAITritonMxfp4ExpertsMonolithic,
        )

        # NOTE: prefer Monolithic > Modular, so return Monolithic first.
        return [OAITritonMxfp4ExpertsMonolithic, OAITritonExperts]

    if backend == Mxfp4MoeBackend.TRITON_UNFUSED:
        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
            UnfusedOAITritonExperts,
        )

        return [UnfusedOAITritonExperts]

    if backend == Mxfp4MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        return [MarlinExperts]

    if backend == Mxfp4MoeBackend.BATCHED_MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            BatchedMarlinExperts,
        )

        return [BatchedMarlinExperts]

    if backend == Mxfp4MoeBackend.CK:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        return [AiterExperts]

    if backend == Mxfp4MoeBackend.XPU:
        raise NotImplementedError("XPU backend uses XpuMxfp4MoEMethod directly.")

    raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
    """Translate a user-supplied ``moe_backend`` string into a Mxfp4MoeBackend.

    Raises:
        ValueError: if ``runner_backend`` is not one of the supported names.
    """
    mapping: dict[str, Mxfp4MoeBackend] = {
        "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        "flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        "flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
        "triton": Mxfp4MoeBackend.TRITON,
        "marlin": Mxfp4MoeBackend.MARLIN,
        "ck": Mxfp4MoeBackend.CK,
    }
    selected = mapping.get(runner_backend)
    if selected is None:
        raise ValueError(
            f"moe_backend='{runner_backend}' is not supported for MXFP4 MoE. "
            f"Expected one of {list(mapping.keys())}."
        )
    return selected
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
    """
    Return the auto-selection candidates, highest priority first.

    A new list is built on every call, so callers may mutate the result.
    Only includes BF16 backends. MXFP8 backends are selected via env vars.
    """
    return [
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.CK,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.TRITON_UNFUSED,
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
    ]
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
    """Return kMxfp8Dynamic for the two MXFP4_MXFP8 backends, else None."""
    mxfp8_activation_backends = (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )
    return kMxfp8Dynamic if backend in mxfp8_activation_backends else None
def select_mxfp4_moe_backend(
    config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
    """
    Select the primary MXFP4 MoE backend.
    Note: Shape-specific fallbacks may still occur at runtime.

    Selection order:
      1. LoRA enabled: TRITON_UNFUSED (when Marlin is explicitly disabled and
         the Triton kernels are usable) or MARLIN; CUDA only.
      2. ``config.moe_backend`` other than "auto": that backend, or an error.
      3. Explicit environment-variable overrides (FlashInfer BF16 / MXFP8,
         Marlin).
      4. The first backend in priority order whose kernel class reports the
         configuration as supported.
      5. XPU fallback, otherwise NONE on platforms that are neither CUDA nor
         ROCm.

    Returns:
        ``(backend, experts_cls)``; ``experts_cls`` is None for the XPU and
        NONE backends.

    Raises:
        NotImplementedError: LoRA on a non-CUDA platform, or no backend
            supports the configuration on CUDA/ROCm.
        ValueError: an explicitly requested backend is unknown or does not
            support the configuration.
    """
    # get_device_capability() may return None; ordering a tuple against None
    # raises TypeError, so check for None before the range comparison.
    triton_kernels_supported = False
    if has_triton_kernels():
        capability = current_platform.get_device_capability()
        triton_kernels_supported = capability is not None and (
            9,
            0,
        ) <= capability < (11, 0)

    # LoRA: separate experts backend path
    if config.is_lora_enabled:
        if not current_platform.is_cuda():
            raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.")
        if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
            logger.info_once("Using Triton backend for mxfp4 lora")
            return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls(
                Mxfp4MoeBackend.TRITON_UNFUSED
            )[0]
        logger.info_once("Using Marlin backend for mxfp4 lora")
        return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0]

    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if config.moe_parallel_config.use_batched_activation_format
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: Mxfp4MoeBackend) -> str:
        return f"Using '{backend.value}' Mxfp4 MoE backend."

    def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
        if reason:
            return (
                f"Mxfp4 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        return (
            f"Mxfp4 MoE backend '{backend.value}' does not support the "
            "deployment configuration."
        )

    def _return_or_raise(
        backend: Mxfp4MoeBackend,
        config: FusedMoEConfig,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
        activation_format: mk.FusedMoEActivationFormat,
    ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
        # Try each kernel class of the requested backend in preference order;
        # the error reports the reason given by the last class tried.
        reason: str | None = None
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend), scope="local")
                return backend, k_cls
        raise ValueError(_make_log_unsupported(backend, reason))

    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_mxfp4_backend(runner_backend)
        # A "marlin" request with batched activations uses the batched variant.
        if (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
            and requested_backend == Mxfp4MoeBackend.MARLIN
        ):
            requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN
        return _return_or_raise(
            requested_backend,
            config,
            kMxfp4Static,
            _backend_activation_key(requested_backend),
            activation_format,
        )

    # Select kernels in order of backend.
    AVAILABLE_BACKENDS = _get_priority_backends()

    # Handle explicit FlashInfer MXFP4 BF16 configuration.
    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
        if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
            # Explicitly disabled: drop both FlashInfer BF16 candidates.
            AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16)
            AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16)
        else:
            if current_platform.is_device_capability(90):
                return _return_or_raise(
                    Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
                    config,
                    kMxfp4Static,
                    None,
                    activation_format,
                )
            if current_platform.is_device_capability_family(100):
                return _return_or_raise(
                    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
                    config,
                    kMxfp4Static,
                    None,
                    activation_format,
                )
            raise ValueError(
                "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 is set but the "
                "current device capability is not supported. "
                "Only SM90 (CUTLASS) and SM100+ (TRTLLM) are supported."
            )

    # Handle explicit FlashInfer MXFP4 MXFP8 TRTLLM configuration.
    if (
        envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")
        and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
    ):
        return _return_or_raise(
            Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
            config,
            kMxfp4Static,
            kMxfp8Dynamic,
            activation_format,
        )

    # Handle explicit FlashInfer MXFP4 MXFP8 CUTLASS configuration.
    if (
        envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS")
        and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
    ):
        return _return_or_raise(
            Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
            config,
            kMxfp4Static,
            kMxfp8Dynamic,
            activation_format,
        )

    # Handle explicit Marlin MXFP4 configuration.
    if envs.is_set("VLLM_MXFP4_USE_MARLIN") and envs.VLLM_MXFP4_USE_MARLIN:
        return _return_or_raise(
            Mxfp4MoeBackend.MARLIN,
            config,
            kMxfp4Static,
            None,
            activation_format,
        )

    # Automatic selection: first supported kernel class in priority order.
    for backend in AVAILABLE_BACKENDS:
        activation_key = _backend_activation_key(backend)
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, kMxfp4Static, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend), scope="local")
                return backend, k_cls
            else:
                logger.debug_once(_make_log_unsupported(backend, reason), scope="local")

    if current_platform.is_xpu():
        backend = Mxfp4MoeBackend.XPU
        logger.info_once(_make_log_backend(backend))
        return backend, None

    if current_platform.is_cuda() or current_platform.is_rocm():
        raise NotImplementedError(
            "No MXFP4 MoE backend supports the deployment configuration."
        )

    return Mxfp4MoeBackend.NONE, None
def mxfp4_round_up_hidden_size_and_intermediate_size(
    backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
) -> tuple[int, int]:
    """Pad the two sizes to the alignment used for ``backend``.

    Alignments applied:
      * Marlin / batched Marlin: intermediate to 128; hidden to 128 on XPU,
        otherwise 256.
      * TRTLLM backends: both to 256.
      * FlashInfer CUTLASS backends: both to 128.
      * Any other backend on ROCm: both to ``get_padding_alignment()``.
      * Otherwise: intermediate to 64, hidden left unchanged.

    Returns:
        ``(hidden_size, intermediate_size)`` after rounding.
    """
    marlin_backends = (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN)
    cutlass_backends = (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )

    if backend in marlin_backends:
        intermediate_size = round_up(intermediate_size, 128)
        hidden_alignment = 128 if current_platform.is_xpu() else 256
        hidden_size = round_up(hidden_size, hidden_alignment)
        return hidden_size, intermediate_size

    if backend in TRTLLM_BACKENDS:
        intermediate_size = round_up(intermediate_size, 256)
        hidden_size = round_up(hidden_size, 256)
        return hidden_size, intermediate_size

    if backend in cutlass_backends:
        intermediate_size = round_up(intermediate_size, 128)
        hidden_size = round_up(hidden_size, 128)
        return hidden_size, intermediate_size

    if current_platform.is_rocm():
        alignment = get_padding_alignment()
        intermediate_size = round_up(intermediate_size, alignment)
        hidden_size = round_up(hidden_size, alignment)
        return hidden_size, intermediate_size

    # Default: only the intermediate size is padded.
    intermediate_size = round_up(intermediate_size, 64)
    return hidden_size, intermediate_size
def convert_to_mxfp4_moe_kernel_format(
    mxfp4_backend: Mxfp4MoeBackend,
    layer: torch.nn.Module,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    _cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Union[torch.Tensor, "PrecisionConfig"],
    Union[torch.Tensor, "PrecisionConfig"],
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Convert loaded weights into backend-specific kernel format.

    Args:
        mxfp4_backend: Backend whose weight layout should be produced.
        layer: Owning module. It is passed to the Marlin helper, and the
            Triton branch deletes its ``w13_weight`` / ``w2_weight``
            attributes.
        w13_weight, w2_weight: Expert weights as loaded.
        w13_weight_scale, w2_weight_scale: Matching block scales.
        w13_bias, w2_bias: Optional biases. The TRTLLM, FlashInfer CUTLASS
            and Triton branches assert that both are present.
        _cache_permute_indices: Cache handed to FlashInfer's permute-index
            helper; must not be None for the TRTLLM backends.

    Returns:
        ``(w13_weight, w2_weight, w13_scale, w2_scale, w13_bias, w2_bias)``.
        For the Triton backends the two scale entries are ``PrecisionConfig``
        objects rather than tensors.

    Raises:
        ValueError: if ``mxfp4_backend`` is not handled by any branch.
    """
    num_experts = w13_weight.shape[0]
    # w13 dim 1 is halved to get the per-projection intermediate size.
    intermediate_size = w13_weight.shape[1] // 2
    # NOTE(review): the doubling presumably reflects two 4-bit values packed
    # per stored element -- confirm against the weight loader.
    hidden_size = w13_weight.shape[2] * 2
    sf_block_size = 32  # mxfp4 block size
    if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
            prepare_moe_mxfp4_layer_for_marlin,
        )

        # Marlin repacking is fully delegated to the helper.
        return prepare_moe_mxfp4_layer_for_marlin(
            layer,
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )
    elif mxfp4_backend in TRTLLM_BACKENDS:
        assert _cache_permute_indices is not None
        from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache

        # gemm1_alpha/beta/clamp_limit are created by the expert class
        # (TrtLlmMxfp4ExpertsBase), not on the layer.
        w13_weight = w13_weight.data
        w2_weight = w2_weight.data
        w13_weight_scale = w13_weight_scale.data
        w2_weight_scale = w2_weight_scale.data
        assert w13_bias is not None and w2_bias is not None
        w13_bias = w13_bias.data.to(torch.float32)
        w2_bias = w2_bias.data.to(torch.float32)

        # Swap w1 and w3 as the definition of swiglu is different in trtllm-gen
        def swap_every_two_rows(x, axis=-1):
            # Split `axis` into (n/2, 2) pairs, flip each pair, then restore
            # the original shape.
            shape = x.shape
            if axis < 0:
                axis = len(shape) + axis
            new_shape = list(shape)
            new_shape[axis] = shape[axis] // 2
            new_shape.insert(axis + 1, 2)
            x = x.reshape(*new_shape)
            x = x.flip(axis + 1)
            new_shape = list(shape)
            return x.reshape(*new_shape)

        w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
        w13_weight = swap_every_two_rows(w13_weight, -2)
        w13_bias = swap_every_two_rows(w13_bias, -1)
        # Shuffle weights and scaling factors for transposed mma output
        gemm1_weights_shuffled = []
        gemm1_scales_shuffled = []
        gemm2_weights_shuffled = []
        gemm2_scales_shuffled = []
        gemm1_bias_shuffled = []
        gemm2_bias_shuffled = []
        epilogue_tile_m = 128
        # Each expert is permuted independently; results are stacked below.
        for i in range(num_experts):
            # w13 weight
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_shuffled.append(
                w13_weight[i]
                .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                .contiguous()
            )
            # w13 scale
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w13_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
                    .contiguous()
                )
            )
            # w13 bias
            permute_bias_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm1_bias_shuffled.append(
                w13_bias[i]
                .clone()
                .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                .contiguous()
            )
            # w2 weight
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_shuffled.append(
                w2_weight[i]
                .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                .contiguous()
            )
            # w2 scale
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w2_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
                    .contiguous()
                )
            )
            # w2 bias
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm2_bias_shuffled.append(
                w2_bias[i]
                .clone()
                .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                .contiguous()
            )
        # Re-assemble per-expert results; scales are reshaped to one entry per
        # sf_block_size elements and reinterpreted as float8_e4m3fn.
        w13_weight = torch.stack(gemm1_weights_shuffled)
        w13_weight_scale = (
            torch.stack(gemm1_scales_shuffled)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
            .view(torch.float8_e4m3fn)
        )
        w2_weight = torch.stack(gemm2_weights_shuffled)
        w2_weight_scale = (
            torch.stack(gemm2_scales_shuffled)
            .reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
            .view(torch.float8_e4m3fn)
        )
        w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
        w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
        return (
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )
    elif mxfp4_backend in (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    ):
        # De-interleave and swap for w13 weight, bias, and scales
        # (even rows/entries first, odd ones second, then the two halves are
        # swapped).
        w13_w = w13_weight.data
        gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
        deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
        w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
        w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
        assert w13_bias is not None and w2_bias is not None
        w13_b = w13_bias.data.to(torch.float32)
        gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
        deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
        b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
        w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
        w13_s = w13_weight_scale.data
        gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
        deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
        s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
        w13_scale_swapped = torch.cat([s3, s1], dim=1)
        if mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8:
            from flashinfer import block_scale_interleave

            # MXFP8 variant: scales go through FlashInfer's interleave and
            # keep their original shape.
            orig_shape = w13_scale_swapped.shape
            w13_scale_interleaved = block_scale_interleave(
                w13_scale_swapped.view(torch.uint8)
            ).reshape(orig_shape)
            w2_s = w2_weight_scale.data
            orig_shape = w2_s.shape
            w2_scale_interleaved = block_scale_interleave(
                w2_s.view(torch.uint8)
            ).reshape(orig_shape)
            return (
                w13_weight_swapped,
                w2_weight,
                w13_scale_interleaved,
                w2_scale_interleaved,
                w13_bias_swapped,
                w2_bias,
            )
        else:
            assert mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16

            def _interleave_mxfp4_cutlass_sm90(w):
                # (E, N, K) -> (E, K // 4, N * 4): groups of 4 along the last
                # dim are moved in front of the row dim.
                w_shape = w.shape
                w_interleaved = w.reshape(w_shape[0], w_shape[1], (w_shape[2] // 4), 4)
                w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                w_interleaved = w_interleaved.reshape(
                    w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                )
                return w_interleaved

            w31_scales = w13_scale_swapped.to(torch.uint8)
            w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
            w2_scale = w2_weight_scale.data.to(torch.uint8)
            w2_scale_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scale)
            return (
                w13_weight_swapped,
                w2_weight,
                w31_scales_interleaved,
                w2_scale_interleaved,
                w13_bias_swapped,
                w2_bias,
            )
    elif mxfp4_backend == Mxfp4MoeBackend.CK:
        from vllm._aiter_ops import rocm_aiter_ops

        if w13_bias is not None:
            w13_bias = w13_bias.data.to(torch.float32)
        if w2_bias is not None:
            w2_bias = w2_bias.data.to(torch.float32)
        e, n, k = w13_weight.shape
        # De-interleave w13 rows: gate/up pairs -> contiguous gate, up blocks
        # (written back in place through a uint8 view of w13_weight).
        w13_weight.view(torch.uint8).copy_(
            w13_weight.data.view(torch.uint8)
            .view(e, n // 2, 2, k)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(e, n, k)
        )
        w13_weight_scale.data = (
            w13_weight_scale.data.view(e, n // 2, 2, -1)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(e, n, -1)
        )
        # View as native FP4 dtype for AITER shuffle
        w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2)
        w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)
        # Shuffle weights and scales for AITER CK kernel layout
        w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
        shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
            w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
            num_experts,
            True,
        )
        w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
        shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
            w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
            num_experts,
            False,
        )
        # Permute bias to match de-interleaved weight layout
        if w13_bias is not None:
            w13_bias = (
                w13_bias.data.view(-1, n // 2, 2)
                .permute(0, 2, 1)
                .contiguous()
                .view(-1, n)
            )
        return (
            w13_weight,
            w2_weight,
            shuffled_w13_scale,
            shuffled_w2_scale,
            w13_bias,
            w2_bias,
        )
    elif mxfp4_backend in TRITON_BACKENDS:
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        assert w13_bias is not None and w2_bias is not None
        w13_bias = w13_bias.to(torch.float32)
        w2_bias = w2_bias.to(torch.float32)
        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            w13_weight,
            w13_weight_scale,
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            w2_weight,
            w2_weight_scale,
        )
        # The swizzled scale and flex data travel together in a
        # PrecisionConfig, returned in place of a plain scale tensor.
        w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
        )
        w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
        )
        # NOTE(review): presumably removes the original weights from the
        # layer so they are not kept alongside the swizzled copies -- confirm.
        del layer.w13_weight
        del layer.w2_weight
        return (
            w13_weight,
            w2_weight,
            w13_precision_config,
            w2_precision_config,
            w13_bias,
            w2_bias,
        )
    else:
        raise ValueError(
            f"Unsupported mxfp4_backend: {mxfp4_backend}: "
            f"should be one of: {list(Mxfp4MoeBackend)}."
        )
def make_mxfp4_moe_quant_config(
    mxfp4_backend: Mxfp4MoeBackend,
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig | None:
    """Build the FusedMoEQuantConfig that corresponds to an MXFP4 backend.

    Args:
        mxfp4_backend: Backend the config is built for; selects which
            config factory is called.
        w1_scale: Scale (tensor or PrecisionConfig) forwarded as ``w1_scale``.
        w2_scale: Scale (tensor or PrecisionConfig) forwarded as ``w2_scale``.
        w1_bias: Optional bias forwarded as ``w1_bias``.
        w2_bias: Optional bias forwarded as ``w2_bias``.

    Returns:
        The config produced by ``mxfp4_mxfp8_moe_quant_config`` for the two
        FlashInfer MXFP4/MXFP8 backends, by ``mxfp4_w4a16_moe_quant_config``
        for the Marlin, Triton, FlashInfer MXFP4/BF16 and CK backends, and by
        ``ocp_mx_moe_quant_config(quant_dtype="mxfp4", ...)`` for any other
        backend. No branch visible here returns ``None``.
    """
    # Every factory receives the same scales and biases.
    scales_and_biases = {
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
    }

    if mxfp4_backend in (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    ):
        return mxfp4_mxfp8_moe_quant_config(**scales_and_biases)

    if mxfp4_backend in (
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.TRITON_UNFUSED,
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.CK,
    ):
        return mxfp4_w4a16_moe_quant_config(**scales_and_biases)

    # Fallback for every backend not listed above.
    return ocp_mx_moe_quant_config(quant_dtype="mxfp4", **scales_and_biases)
def make_mxfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    mxfp4_backend: Mxfp4MoeBackend,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a FusedMoEKernel for the given MXFP4 backend.

    Args:
        moe_quant_config: Quantization config handed to both the
            prepare/finalize factory and the experts constructor.
        moe_config: MoE config; also supplies ``moe_parallel_config`` and
            ``disable_inplace``.
        experts_cls: Experts class to instantiate. Whether it subclasses
            ``mk.FusedMoEExpertsMonolithic`` decides ``use_monolithic``.
        mxfp4_backend: Backend in use; backends in ``TRTLLM_BACKENDS`` never
            run in place.
        routing_tables: Optional routing tables forwarded to the
            prepare/finalize factory.
        shared_experts: Optional shared-experts module; only attached to the
            kernel when ``use_deepep_ll_kernels`` is set.

    Returns:
        The constructed ``mk.FusedMoEKernel``.
    """
    # Step 1: prepare/finalize object.
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
    )
    assert prepare_finalize is not None
    logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")

    # Step 2: experts. The batched activation format needs two extra
    # constructor arguments taken from the prepare/finalize object.
    experts_kwargs = {
        "moe_config": moe_config,
        "quant_config": moe_quant_config,
    }
    uses_batched_format = (
        prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts
    )
    if uses_batched_format:
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts_kwargs["max_num_tokens"] = max_num_tokens
        experts_kwargs["num_dispatchers"] = prepare_finalize.num_dispatchers()
    experts = experts_cls(**experts_kwargs)

    # Step 3: kernel.
    parallel_config = moe_config.moe_parallel_config
    kernel_shared_experts = (
        shared_experts if parallel_config.use_deepep_ll_kernels else None
    )
    run_inplace = (
        not moe_config.disable_inplace and mxfp4_backend not in TRTLLM_BACKENDS
    )
    return mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        shared_experts=kernel_shared_experts,
        moe_parallel_config=parallel_config,
        inplace=run_inplace,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
backend_to_kernel_cls,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kMxfp8Dynamic,
kMxfp8Static,
)
logger = init_logger(__name__)
# Fp8MoeBackend members that the MXFP8 auto-selection path iterates over;
# currently a single entry.
_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset(
{
Fp8MoeBackend.FLASHINFER_TRTLLM,
}
)
# Closed set of MXFP8 MoE backend identifiers; FLASHINFER_TRTLLM is the only
# member defined here.
class MxFp8MoeBackend(Enum):
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
# Lookup table from a backend name string to its Fp8MoeBackend member.
_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = {
"flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
}
def _select_kernel_cls(
    backend: Fp8MoeBackend,
    config: FusedMoEConfig,
) -> type[mk.FusedMoEExperts]:
    """Return the first expert class of ``backend`` that supports ``config``.

    Candidates come from ``backend_to_kernel_cls(backend)`` and are probed in
    order with ``is_supported_config`` using the ``kMxfp8Static`` /
    ``kMxfp8Dynamic`` keys and the activation format implied by
    ``config.moe_parallel_config``.

    Raises:
        ValueError: If no candidate reports support. The message carries the
            reason given by the last candidate tried (``None`` when there
            were no candidates).
    """
    if config.moe_parallel_config.use_batched_activation_format:
        activation_format = mk.FusedMoEActivationFormat.BatchedExperts
    else:
        activation_format = mk.FusedMoEActivationFormat.Standard

    rejection_reason: str | None = None
    for candidate in backend_to_kernel_cls(backend):
        is_supported, why_not = candidate.is_supported_config(
            candidate,
            config,
            kMxfp8Static,
            kMxfp8Dynamic,
            activation_format,
        )
        if is_supported:
            return candidate
        # Remember only the most recent rejection for the error message.
        rejection_reason = why_not

    raise ValueError(
        f"No supported MXFP8 expert class for {backend.value}: {rejection_reason}"
    )
def select_mxfp8_moe_backend(
    config: FusedMoEConfig,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
    """Select the MXFP8 MoE backend and the best expert class.

    When ``config.moe_backend`` is not ``"auto"`` the name is resolved through
    ``_BACKEND_NAME_MAP``; otherwise the first entry of ``_SUPPORTED_BACKENDS``
    is used. In both cases the expert class is chosen by
    ``_select_kernel_cls``.

    Args:
        config: MoE configuration; ``is_lora_enabled`` and ``moe_backend``
            are read here.

    Returns:
        A tuple of (fp8_backend, experts_cls).

    Raises:
        NotImplementedError: If LoRA is enabled in ``config``.
        ValueError: If the requested backend name is unknown, if no backend
            is available, or (propagated from ``_select_kernel_cls``) if no
            expert class supports ``config``.
    """
    if config.is_lora_enabled:
        raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")

    runner_backend = config.moe_backend
    if runner_backend != "auto":
        backend = _BACKEND_NAME_MAP.get(runner_backend)
        if backend is None:
            raise ValueError(
                f"moe_backend='{runner_backend}' is not supported for "
                f"MXFP8 MoE. Expected one of "
                f"{list(_BACKEND_NAME_MAP.keys())}."
            )
        logger.info_once(
            "Using '%s' MxFp8 MoE backend (user-requested).",
            backend.value,
        )
        return backend, _select_kernel_cls(backend, config)

    # Auto-select: pick the first supported backend.
    for backend in _SUPPORTED_BACKENDS:
        logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
        return backend, _select_kernel_cls(backend, config)

    # Only reachable when _SUPPORTED_BACKENDS is empty.
    raise ValueError("No MXFP8 MoE backends available.")
......@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
mxfp4_w4a16_moe_quant_config,
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
......@@ -347,16 +346,6 @@ def convert_to_nvfp4_moe_kernel_format(
)
def make_mxfp4_moe_quant_config(
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Wrap the two weight scales in an MXFP4 W4A16 quant config.

    ``w13_scale`` is forwarded as ``w1_scale`` and ``w2_scale`` as
    ``w2_scale`` to ``mxfp4_w4a16_moe_quant_config``.
    """
    quant_config = mxfp4_w4a16_moe_quant_config(
        w1_scale=w13_scale,
        w2_scale=w2_scale,
    )
    return quant_config
def make_nvfp4_moe_quant_config(
backend: NvFp4MoeBackend,
w13_scale: torch.Tensor,
......
......@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Static,
)
......@@ -201,6 +202,8 @@ def rocm_aiter_fused_experts(
activation_method = ActivationMethod.SILU
elif activation == MoEActivation.GELU:
activation_method = ActivationMethod.GELU
elif activation == MoEActivation.SWIGLUOAI:
activation_method = rocm_aiter_ops.get_aiter_activation_type("swiglu")
else:
raise ValueError(f"Unsupported activation: {activation}")
......@@ -247,8 +250,8 @@ def rocm_aiter_fused_experts(
else:
quant_method = QuantMethod.NO.value
# quark moe for mxfp4 w_dtype mxfp4 a_dtype
if quant_config.use_mxfp4_w4a4:
# mxfp4: both w4a4 (quark) and w4a16 (oracle CK) use BLOCK_1X32
if quant_config.use_mxfp4_w4a4 or quant_config.use_mxfp4_w4a16:
quant_method = QuantMethod.BLOCK_1X32.value
# w8a8 block-scaled
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
......@@ -289,13 +292,20 @@ def rocm_aiter_fused_experts(
doweight_stage1=apply_router_weight_on_input,
num_local_tokens=num_local_tokens,
output_dtype=output_dtype,
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
)
class AiterExperts(mk.FusedMoEExpertsModular):
@property
def expects_unquantized_inputs(self) -> bool:
    """Whether these experts want activations that are not yet quantized.

    Returns:
        ``False`` when ``self.moe_config.use_mori_kernels`` is set, ``True``
        otherwise.
    """
    # When paired with MoRI, the prepare/finalize handles FP8
    # quantization during dispatch to reduce network traffic,
    # so we should not defer input quantization.
    # Otherwise, AITER fused MoE kernels handle input quantization
    # internally via a single fused kernel.
    return not self.moe_config.use_mori_kernels
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
......@@ -314,21 +324,23 @@ class AiterExperts(mk.FusedMoEExpertsModular):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# TODO(rob): AITER also supports MXFP4, which is not
# yet supported via an Oracle. Once it is, we will add
# MXFP4 to this list.
SUPPORTED_W_A = [
(None, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kMxfp4Static, None),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Return True if ``activation`` is SILU, GELU or SWIGLUOAI."""
    return activation in [
        MoEActivation.SILU,
        MoEActivation.GELU,
        MoEActivation.SWIGLUOAI,
    ]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......
......@@ -3,9 +3,11 @@
import torch
from torch.nn.parameter import Parameter
import vllm._custom_ops as ops
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@PluggableLayer.register("gate_linear")
......@@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear):
"""MoE gate linear layer with three-tier GEMM dispatch:
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
3. F.linear via ReplicatedLinear (ultimate fallback)
2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims)
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
4. F.linear via ReplicatedLinear (ultimate fallback)
The ``out_dtype`` attribute is mutable and can be set after init
(e.g. when the required dtype depends on the expert quantization
......@@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear):
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def __init__(
self,
input_size: int,
......@@ -65,6 +72,15 @@ class GateLinear(ReplicatedLinear):
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
)
# gpt-oss specialized kernel eligibility (SM90+, exact dims)
self.allow_gpt_oss_router_gemm = (
self.weight.dtype == torch.bfloat16
and current_platform.is_cuda()
and is_hopper_or_blackwell
and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS
and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
# cuBLAS bf16→fp32 eligibility
self.allow_cublas_router_gemm = (
self.allow_specialized_router_gemm
......@@ -92,8 +108,6 @@ class GateLinear(ReplicatedLinear):
def forward(
self, x: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
import vllm._custom_ops as ops
# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
output = ops.dsv3_router_gemm(
......@@ -103,15 +117,47 @@ class GateLinear(ReplicatedLinear):
)
return output, None
# Tier 2: cuBLAS bf16→fp32
# Tier 2: gpt-oss specialized kernel
if self.allow_gpt_oss_router_gemm:
output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias)
return output, None
# Tier 3: cuBLAS bf16→fp32
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
output = ops.router_gemm_bf16_fp32(x, self.weight)
return output, None
# Tier 3: F.linear (ReplicatedLinear)
# Tier 4: F.linear (ReplicatedLinear)
if self.out_dtype is not None and x.dtype != self.weight.dtype:
x = x.to(self.weight.dtype)
output, output_bias = super().forward(x)
if self.out_dtype is not None and output.dtype != self.out_dtype:
output = output.to(self.out_dtype)
return output, output_bias
def gpt_oss_router_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """
    Router GEMM that picks its implementation from the token count.

    Inputs with at most 128 rows go to the min-latency
    ``ops.gpt_oss_router_gemm`` kernel; larger inputs use
    ``torch.nn.functional.linear``.

    This must be wrapped in a custom op because our torch.compile integration
    does not support runtime dispatching on num_tokens.
    """
    num_tokens = x.shape[0]
    if num_tokens > 128:
        return torch.nn.functional.linear(x, weight, bias)
    return ops.gpt_oss_router_gemm(x, weight, bias)
def gpt_oss_router_gemm_fake(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Fake implementation: allocate an uninitialised output of the right shape.

    The result has one row per row of ``x`` and one column per row of
    ``weight``, with the dtype and device of ``x``. ``bias`` is ignored.
    """
    num_tokens, num_outputs = x.shape[0], weight.shape[0]
    return x.new_empty((num_tokens, num_outputs))
# Register the dispatching router GEMM as the custom op "gpt_oss_router_gemm",
# with gpt_oss_router_gemm_fake as its fake implementation.
direct_register_custom_op(
op_name="gpt_oss_router_gemm",
op_func=gpt_oss_router_gemm_impl,
fake_impl=gpt_oss_router_gemm_fake,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from contextlib import nullcontext
from typing import TYPE_CHECKING
......@@ -82,9 +83,22 @@ def _moe_forward(
layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
layer, hidden_states, router_logits, shared_experts_input
)
runner = layer.runner
with runner._sequence_parallel_context():
if runner.use_dp_chunking:
return runner.forward_impl_chunked(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
else:
return runner.forward_impl(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _moe_forward_fake(
......@@ -105,9 +119,22 @@ def _moe_forward_shared(
layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
layer, hidden_states, router_logits, shared_experts_input
)
runner = layer.runner
with runner._sequence_parallel_context():
if runner.use_dp_chunking:
return runner.forward_impl_chunked(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
else:
return runner.forward_impl(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _moe_forward_shared_fake(
......@@ -191,10 +218,17 @@ class DefaultMoERunner(MoERunner):
self.reduce_results = reduce_results
self.enable_dbo = enable_dbo
# Chunked all2all staging tensor
# TODO(bnell) rename these?
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
self._maybe_init_dp_chunking()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
self.use_shared_experts_stream = False
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
self.shared_experts_stream = None
......@@ -210,23 +244,20 @@ class DefaultMoERunner(MoERunner):
# Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer.layer_name
self.moe_forward = self._select_forward(layer)
def _select_forward(self, layer: torch.nn.Module) -> Callable:
    """Pick the callable used to run the MoE forward pass.

    On TPU and CPU the plain Python functions are returned; on every other
    platform the registered ``torch.ops.vllm`` custom ops are returned. In
    both cases the ``*_shared`` variant is chosen when ``self.shared_experts``
    is set.

    Args:
        layer: The MoE layer. Not read by this method.

    Returns:
        The forward callable for the caller to store.
    """
    if current_platform.is_tpu() or current_platform.is_cpu():
        # TODO: Once the OOM issue for the TPU backend is resolved, we
        # will switch to using the moe_forward custom op.
        # Note: CPU doesn't require wrapped forward_impl.
        return _moe_forward if self.shared_experts is None else _moe_forward_shared

    return (
        torch.ops.vllm.moe_forward
        if self.shared_experts is None
        else torch.ops.vllm.moe_forward_shared
    )
@property
def use_dp_chunking(self) -> bool:
......@@ -241,22 +272,8 @@ class DefaultMoERunner(MoERunner):
self,
hidden_states: torch.Tensor,
shared_input: torch.Tensor | None,
has_separate_shared_experts: bool,
use_chunked_impl: bool,
) -> tuple[bool, torch.Tensor | None]:
use_shared_experts_stream = (
current_platform.is_cuda()
and has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
shared_experts_input: torch.Tensor | None = None
if use_shared_experts_stream:
):
if self.use_shared_experts_stream:
assert self.shared_experts_stream is not None
assert self.moe_config.disable_inplace
......@@ -278,12 +295,11 @@ class DefaultMoERunner(MoERunner):
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
return use_shared_experts_stream, shared_experts_input
def ensure_dp_chunking_init(self):
if not self.use_dp_chunking or self.batched_hidden_states is not None:
def _maybe_init_dp_chunking(self):
if not self.use_dp_chunking:
return
assert self.batched_hidden_states is None
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
......@@ -309,6 +325,38 @@ class DefaultMoERunner(MoERunner):
device=device,
)
@property
def has_separate_shared_experts(self) -> bool:
    """True when shared experts exist and the quant method does not own them.

    Returns ``False`` if ``self.quant_method.mk_owns_shared_expert`` is set;
    otherwise reports whether ``self.shared_experts`` is not ``None``.
    """
    if self.quant_method.mk_owns_shared_expert:
        return False
    return self.shared_experts is not None
def _apply_shared_experts(
self,
hidden_states: torch.Tensor,
allow_streaming: bool = False,
) -> torch.Tensor | None:
"""Run the shared experts on ``hidden_states`` when they are separate.

Returns ``None`` when ``self.has_separate_shared_experts`` is false.
Otherwise returns ``self.shared_experts(hidden_states)``, computed under
``self.shared_experts_stream`` when both ``self.use_shared_experts_stream``
and ``allow_streaming`` are true, and on the current stream otherwise.
"""
shared_output: torch.Tensor | None = None
if self.has_separate_shared_experts:
assert self.shared_experts is not None
if self.use_shared_experts_stream and allow_streaming:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
# NOTE(review): no clone() is performed in this method; presumably
# the caller passes an already-cloned tensor - confirm.
shared_output = self.shared_experts(hidden_states)
# Make the current stream wait for the shared-experts stream.
current_stream().wait_stream(self.shared_experts_stream)
else:
shared_output = self.shared_experts(hidden_states)
return shared_output
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
......@@ -322,7 +370,6 @@ class DefaultMoERunner(MoERunner):
Therefore it is required that we reduce the shared_experts output
early.
"""
assert self.quant_method is not None
return (
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
......@@ -357,7 +404,7 @@ class DefaultMoERunner(MoERunner):
return result
return hidden_states
def _reduce_output(
def _maybe_reduce_output(
self,
states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
trunc_sizes: list[int],
......@@ -397,25 +444,21 @@ class DefaultMoERunner(MoERunner):
return "from_forward_context"
return self.layer_name
def forward(
def _maybe_pad_hidden_states(
self,
original_hidden_states: torch.Tensor | None,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
if self.shared_experts is not None:
original_hidden_states = hidden_states
original_hidden_dim = hidden_states.shape[-1]
else:
original_hidden_states = None
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states = self.apply_routed_input_transform(hidden_states)
# This is the dimension after transform (for routed expert output slicing)
) -> tuple[torch.Tensor, list[int]]:
original_hidden_dim = (
original_hidden_states.shape[-1]
if original_hidden_states is not None
else 0
)
transformed_hidden_dim = hidden_states.shape[-1]
if self.moe_config.hidden_dim != transformed_hidden_dim:
if (
not self.quant_method.skip_forward_padding
and self.moe_config.hidden_dim != transformed_hidden_dim
):
hidden_states = F.pad(
hidden_states,
(0, self.moe_config.hidden_dim - transformed_hidden_dim),
......@@ -423,134 +466,235 @@ class DefaultMoERunner(MoERunner):
value=0.0,
)
fused_output = self.moe_forward(
hidden_states,
router_logits,
original_hidden_states,
self._encode_layer_name(),
)
if self.shared_experts is not None:
orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
else:
orig_hidden_dims = [transformed_hidden_dim]
return self._reduce_output(fused_output, orig_hidden_dims)
return hidden_states, orig_hidden_dims
def forward_impl_chunked(
def _apply_quant_method(
self,
layer: torch.nn.Module,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
full_shared_input: torch.Tensor | None,
has_separate_shared_experts: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_input: torch.Tensor | None,
run_shared_experts_before: bool = True,
) -> tuple[torch.Tensor | None, torch.Tensor]:
shared_input = shared_input if shared_input is not None else hidden_states
shared_output: torch.Tensor | None = None
# Run this before quant_method to avoid inplace issues.
if run_shared_experts_before:
shared_output = self._apply_shared_experts(shared_input, False)
if self.quant_method.is_monolithic:
result = self.quant_method.apply_monolithic(
layer=layer,
x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
result = self.quant_method.apply(
layer=layer,
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_input,
)
if isinstance(result, tuple):
assert shared_output is None
shared_output, hidden_states = result
else:
hidden_states = result
if not run_shared_experts_before and self.has_separate_shared_experts:
assert shared_output is None
shared_output = self._apply_shared_experts(shared_input, True)
return shared_output, hidden_states
def _sequence_parallel_context(self):
    """Return the context manager to wrap a forward call in.

    When the forward context carries ``dp_metadata``, this is
    ``dp_metadata.sp_local_sizes(self.moe_config.sp_size)``; otherwise it is
    a ``nullcontext()``.
    """
    dp_metadata = get_forward_context().dp_metadata
    if not dp_metadata:
        return nullcontext()
    return dp_metadata.sp_local_sizes(self.moe_config.sp_size)
def _allocate_dp_chunking_outputs(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor | None, torch.Tensor]:
assert self.use_dp_chunking
# Assert the inputs are of the proper type and shape.
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
assert self.batched_hidden_states.dtype == hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
assert self.batched_router_logits.dtype == router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
# TODO(bnell): Fix shared_expert_inputs w/chunking.
# assert shared_input is None, (
# "Routed input transform is not currently supported with DP chunking."
# )
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == router_logits.size(-1)
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
final_fused_hidden_states = torch.empty_like(hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
shared_input = (
full_shared_input[chunk_start:chunk_end, :]
if full_shared_input is not None
else None
)
final_shared_hidden_states = torch.empty_like(hidden_states)
else:
final_shared_hidden_states = None
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
# This is only true when DBO has been enabled in the config.
# Both tensors will have an outer dimension for the ubatch id
if self.batched_hidden_states.dim() == 3:
assert self.batched_router_logits.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
else:
batched_hidden_states = self.batched_hidden_states
batched_router_logits = self.batched_router_logits
return final_shared_hidden_states, final_fused_hidden_states
def _maybe_gate(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> torch.Tensor:
    """Compute router logits with ``self.gate`` when a gate is attached.

    Returns the incoming ``router_logits`` unchanged when ``self.gate`` is
    ``None``; otherwise returns the first element of
    ``self.gate(hidden_states)``.
    """
    # (Note: the gate is only provided when "overlapped mode" is on, to allow
    # parallel execution of shared experts with the FusedMoE via a
    # separate cuda stream.)
    gate = self.gate
    if gate is None:
        return router_logits
    gated_logits, _ = gate(hidden_states)
    return gated_logits
@property
def do_naive_dispatch_combine(self) -> bool:
    """True when the naive dispatch/combine path applies.

    That is the case when ``self.moe_config.dp_size`` exceeds 1 and the quant
    method does not report ``supports_internal_mk``.
    """
    if not self.moe_config.dp_size > 1:
        return False
    return not self.quant_method.supports_internal_mk
assert (
batched_hidden_states.size(0) # type: ignore
>= chunk_size
def _maybe_dispatch(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if self.do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
hidden_states,
router_logits,
self.moe_config.is_sequence_parallel,
)
assert (
batched_router_logits.size(0) # type: ignore
>= chunk_size
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstraction to better support PCP.
if self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
return hidden_states, router_logits
shared_input = (
shared_input if shared_input is not None else staged_hidden_states
def _maybe_combine(
self,
shared_output: torch.Tensor | None,
hidden_states: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
if self.do_naive_dispatch_combine:
hidden_states = get_ep_group().combine(
hidden_states, self.moe_config.is_sequence_parallel
)
# Matrix multiply.
if self.quant_method.is_monolithic:
assert has_separate_shared_experts or self.shared_experts is None
final_hidden_states = self.quant_method.apply_monolithic(
layer=layer,
x=staged_hidden_states,
router_logits=staged_router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=staged_hidden_states,
router_logits=staged_router_logits,
)
if self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().reduce_scatter(
hidden_states,
dim=0,
)
# need RS for shared_output?
final_hidden_states = self.quant_method.apply(
layer=layer,
x=staged_hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_input,
)
if self.shared_experts is not None:
assert shared_output is not None
return shared_output, hidden_states
else:
return hidden_states
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
if self.shared_experts is not None:
original_hidden_states = hidden_states
else:
original_hidden_states = None
shared_output = self.shared_experts(shared_input)
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states = self.apply_routed_input_transform(hidden_states)
final_hidden_states = (
shared_output,
final_hidden_states,
)
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
original_hidden_states,
hidden_states,
)
if not skip_result_store:
if self.shared_experts is None:
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states, non_blocking=True
)
else:
full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states[0], non_blocking=True
)
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states[1], non_blocking=True
)
fused_output = self.moe_forward(
hidden_states,
router_logits,
original_hidden_states,
self._encode_layer_name(),
)
return self._maybe_reduce_output(fused_output, og_hidden_dims)
def _slice_and_copy_input(
self,
out_slice: torch.Tensor,
orig: torch.Tensor | None,
start: int,
end: int,
) -> torch.Tensor:
assert orig is not None
slice_size = end - start
orig_slice = orig[start:end, :]
if self.enable_dbo:
assert out_slice.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
out_slice = out_slice[batch_buffer_idx, :]
assert out_slice.size(0) >= slice_size
out_slice = out_slice[:slice_size, :]
out_slice.copy_(orig_slice, non_blocking=True)
return out_slice
def forward_impl_chunked(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Gate overlap not supported when chunking is enabled. Run the
# gate first.
router_logits = self._maybe_gate(hidden_states, router_logits)
final_shared_hidden_states, final_fused_hidden_states = (
self._allocate_dp_chunking_outputs(hidden_states, router_logits)
)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
......@@ -564,7 +708,7 @@ class DefaultMoERunner(MoERunner):
max_tokens_across_dispatchers, self.moe_config.sp_size
)
num_tokens = full_hidden_states.size(0)
num_tokens = hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
):
......@@ -575,17 +719,55 @@ class DefaultMoERunner(MoERunner):
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
with ctx.dp_metadata.chunked_sizes(
chunk_sizes = ctx.dp_metadata.chunked_sizes(
self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
):
process_chunk(
chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
)
with chunk_sizes:
hidden_states_chunk = self._slice_and_copy_input(
self.batched_hidden_states,
hidden_states,
chunk_start,
chunk_end,
)
router_logits_chunk = self._slice_and_copy_input(
self.batched_router_logits,
router_logits,
chunk_start,
chunk_end,
)
shared_input_chunk = (
shared_input[chunk_start:chunk_end, :]
if shared_input is not None
else None
)
shared_output_chunk, hidden_states_chunk = self._apply_quant_method(
layer=layer,
hidden_states=hidden_states_chunk,
router_logits=router_logits_chunk,
shared_input=shared_input_chunk,
)
# Store outputs
# TODO(bnell): document when chunk_start >= num_tokens
if chunk_start < num_tokens:
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
hidden_states_chunk, non_blocking=True
)
if self.shared_experts is not None:
assert shared_output_chunk is not None
assert final_shared_hidden_states is not None
final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
shared_output_chunk, non_blocking=True
)
if self.shared_experts is None:
return full_fused_final_hidden_states
return final_fused_hidden_states
else:
return (full_shared_final_hidden_states, full_fused_final_hidden_states)
assert final_shared_hidden_states is not None
return (final_shared_hidden_states, final_fused_hidden_states)
def forward_impl(
self,
......@@ -594,148 +776,51 @@ class DefaultMoERunner(MoERunner):
router_logits: torch.Tensor,
shared_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
self.ensure_dp_chunking_init()
has_separate_shared_experts = (
not self.quant_method.mk_owns_shared_expert
and self.shared_experts is not None
self.use_shared_experts_stream = (
current_platform.is_cuda()
and self.has_separate_shared_experts
and not self.use_dp_chunking
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
use_chunked_impl = self.use_dp_chunking
# Check if we need to run shared experts before matrix multiply because
# matrix multiply may modify the hidden_states.
run_shared_experts_before = (
self.has_separate_shared_experts and not self.use_shared_experts_stream
)
use_shared_experts_stream, shared_experts_input = (
# The shared experts stream must be set up before calling the gate so they
# can be overlapped.
if not run_shared_experts_before:
self._maybe_setup_shared_experts_stream(
hidden_states,
shared_input,
has_separate_shared_experts,
use_chunked_impl,
)
)
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
if use_chunked_impl:
return self.forward_impl_chunked(
layer,
hidden_states,
router_logits,
shared_input,
has_separate_shared_experts,
)
router_logits = self._maybe_gate(hidden_states, router_logits)
# NOTE(rob): once we finish migrating all the quant methods to use
# MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine = (
self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
# TODO(bnell): parts of the dispatch/combine steps will go away once
# #32567 lands and the remaining kernels are made MKs. The PCP
# code will probably remain
hidden_states, router_logits = self._maybe_dispatch(
layer,
hidden_states,
router_logits,
)
ctx = get_forward_context()
sp_ctx = (
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
if ctx.dp_metadata
else nullcontext()
shared_output, hidden_states = self._apply_quant_method(
layer=layer,
hidden_states=hidden_states,
router_logits=router_logits,
shared_input=shared_input,
run_shared_experts_before=run_shared_experts_before,
)
with sp_ctx:
# Run shared experts before the matrix multiply,
# because the matrix multiply may modify hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_input = (
shared_input if shared_input is not None else hidden_states
)
shared_output = self.shared_experts(shared_input)
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
hidden_states,
router_logits,
self.moe_config.is_sequence_parallel,
)
# NOTE: Similar to DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify the All2AllManager abstraction to better support PCP.
if self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
# Matrix multiply.
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=layer,
x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
final_hidden_states = self.quant_method.apply(
layer=layer,
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_input,
)
if has_separate_shared_experts:
assert self.shared_experts is not None
if use_shared_experts_stream:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(shared_experts_input)
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,
)
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(
states, self.moe_config.is_sequence_parallel
)
if self.moe_config.pcp_size > 1:
states = get_pcp_group().reduce_scatter(
states,
dim=0,
)
return states
if self.shared_experts is not None:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
else:
return combine_output(final_hidden_states)
return self._maybe_combine(
shared_output,
hidden_states,
)
......@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
......@@ -199,7 +200,7 @@ def _mxfp8_e4m3_quantize(
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None
assert block_shape is None or block_shape == [1, 32]
return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)
......@@ -265,7 +266,7 @@ def moe_kernel_quantize_input(
# weights are already dequantized, and we proceed with normal
# activation quantization below.
if quant_dtype == torch.float8_e4m3fn:
if quant_dtype == current_platform.fp8_dtype():
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
......@@ -316,27 +317,6 @@ def normalize_batched_scales_shape(
return scales
def _validate_scale_shape(
a: torch.Tensor,
a_scale: torch.Tensor | None,
per_act_token_quant: bool,
block_shape: list[int] | None,
) -> None:
if a_scale is None:
return
if not per_act_token_quant and block_shape is None:
assert a_scale.numel() == 1, f"{a_scale.shape}"
elif per_act_token_quant:
assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, (
f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1"
)
else:
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
# Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378
......
......@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
num_actual_tokens = attn_metadata.num_actual_tokens
constant_caches = self.kv_cache[forward_context.virtual_engine]
constant_caches = self.kv_cache[0]
q_proj_states = q_proj_states[:num_actual_tokens]
k_proj_states = k_proj_states[:num_actual_tokens]
......
......@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
kv_cache = self.kv_cache[0][0]
state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata
......
......@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
......
......@@ -574,7 +574,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
......
......@@ -333,13 +333,13 @@ def selective_state_update(
dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2:
out = out.unsqueeze(1)
if num_accepted_tokens is not None:
assert state_batch_indices is not None and state_batch_indices.dim() == 2
assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
if state_batch_indices is not None and state_batch_indices.dim() == 1:
state_batch_indices = state_batch_indices.unsqueeze(1)
if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
if num_accepted_tokens is not None:
assert state_batch_indices is not None and state_batch_indices.dim() == 2
assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
_, nheads, dim, dstate = state.shape
batch = x.shape[0]
......
......@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
......
......@@ -16,25 +16,22 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
def get_classification_act_fn(
def get_act_fn(
config: PretrainedConfig,
static_num_labels: bool = True,
) -> "PoolerActivation":
# get classification act_fn
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type = getattr(config, "problem_type", "")
if problem_type == "regression":
return PoolerIdentity()
if problem_type == "single_label_classification":
return PoolerClassify()
return PoolerClassify(static_num_labels=static_num_labels)
if problem_type == "multi_label_classification":
return PoolerMultiLabelClassify()
return PoolerClassify()
def get_cross_encoder_act_fn(
config: PretrainedConfig,
) -> "PoolerActivation":
# get cross_encoder act_fn
function_name: str | None = None
if (
hasattr(config, "sentence_transformers")
......@@ -55,24 +52,16 @@ def get_cross_encoder_act_fn(
fn = resolve_obj_by_qualname(function_name)()
return PoolerActivation.wraps(fn)
return PoolerClassify()
return PoolerClassify(static_num_labels=static_num_labels)
def resolve_classifier_act_fn(
model_config: ModelConfig,
static_num_labels: bool = True,
act_fn: "PoolerActivation | str | None" = None,
act_fn: "PoolerActivation | None" = None,
):
if isinstance(act_fn, str):
if act_fn == "classify":
return get_classification_act_fn(model_config.hf_config)
if act_fn == "score":
return get_cross_encoder_act_fn(model_config.hf_config)
raise ValueError(f"act_fn [{act_fn=}] not supported.")
if act_fn is None:
return PoolerClassify(static_num_labels=static_num_labels)
return get_act_fn(model_config.hf_config, static_num_labels)
assert callable(act_fn)
return act_fn
......@@ -97,9 +86,8 @@ class PoolerActivation(nn.Module, ABC):
def forward(self, pooled_data: _T) -> _T:
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
# classify -> (batch_size, num_classes)
# embed -> (batch_size, embedding_size) or list(embedding_size)
if isinstance(pooled_data, list):
return [self.forward_chunk(data) for data in pooled_data]
......
......@@ -56,29 +56,31 @@ class EmbeddingPoolerHead(SequencePoolerHead):
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
# pooled_data shape: [batchsize, hidden_size]
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)
# Apply ST projector
if self.projector is not None:
pooled_data = self.projector(pooled_data)
# pooled_data shape: [batchsize, embedding_dimension]
embeddings = self.projector(pooled_data)
else:
embeddings = pooled_data
# embeddings shape: [batchsize, embedding_size]
# for matryoshka representation
dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)
if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
assert len(embeddings) == len(dimensions_list)
if len(set(dimensions_list)) == 1 and not isinstance(embeddings, list):
# if all dimensions are the same
d = dimensions_list[0]
pooled_data = pooled_data[..., :d]
embeddings = embeddings[..., :d]
else:
pooled_data = [
embeddings = [
vecs if d is None else vecs[..., :d]
for vecs, d in zip(pooled_data, dimensions_list)
for vecs, d in zip(embeddings, dimensions_list)
]
# for normalize
......@@ -86,15 +88,15 @@ class EmbeddingPoolerHead(SequencePoolerHead):
flags = [p.use_activation for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
embeddings = self.activation(embeddings)
else:
pooled_data = [
embeddings = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
for vecs, f in zip(embeddings, flags)
]
# pooled_data shape: [batchsize, embedding_dimension]
return pooled_data
# embeddings shape: [batchsize, embedding_size]
return embeddings
class ClassifierPoolerHead(SequencePoolerHead):
......@@ -113,7 +115,7 @@ class ClassifierPoolerHead(SequencePoolerHead):
self.activation = activation
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"classify", "score"}
return {"classify"}
def forward(
self,
......@@ -131,21 +133,23 @@ class ClassifierPoolerHead(SequencePoolerHead):
pooled_data = pooled_data.to(self.head_dtype)
if self.classifier is not None:
pooled_data = self.classifier(pooled_data)
# pooled_data shape: [batchsize, num_labels]
logits = self.classifier(pooled_data)
else:
logits = pooled_data
# logits shape: [batchsize, num_labels]
if self.logit_bias is not None:
pooled_data -= self.logit_bias
logits -= self.logit_bias
if self.activation is not None:
flags = [p.use_activation for p in pooling_params]
if len(set(flags)) == 1:
pooled_data = self.activation(pooled_data) if flags[0] else pooled_data
logits = self.activation(logits) if flags[0] else logits
else:
pooled_data = [
logits = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
for vecs, f in zip(logits, flags)
]
# pooled_data shape: [batchsize, num_labels]
return pooled_data
# logits shape: [batchsize, num_labels]
return logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment