Unverified Commit 4d51588e authored by Yifan Qiao's avatar Yifan Qiao Committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: qizixi <zixi@inferact.ai>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
parent 32e45636
...@@ -299,6 +299,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): ...@@ -299,6 +299,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
tools: list[ChatCompletionFunctionToolParam] | None tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role.""" """The tools for developer role."""
task: str | None
"""Model-specific task marker. Currently passed through for DeepSeek V4."""
ChatCompletionMessageParam: TypeAlias = ( ChatCompletionMessageParam: TypeAlias = (
OpenAIChatCompletionMessageParam OpenAIChatCompletionMessageParam
...@@ -333,6 +336,9 @@ class ConversationMessage(TypedDict, total=False): ...@@ -333,6 +336,9 @@ class ConversationMessage(TypedDict, total=False):
tools: list[ChatCompletionFunctionToolParam] | None tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role.""" """The tools for developer role."""
task: str | None
"""Model-specific task marker. Currently passed through for DeepSeek V4."""
# Passed in by user # Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
...@@ -1566,6 +1572,9 @@ def _parse_chat_message_content( ...@@ -1566,6 +1572,9 @@ def _parse_chat_message_content(
if "name" in message and isinstance(message["name"], str): if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"] result_msg["name"] = message["name"]
if "task" in message and isinstance(message["task"], str):
result_msg["task"] = message["task"]
if role == "developer": if role == "developer":
result_msg["tools"] = message.get("tools", None) result_msg["tools"] = message.get("tools", None)
return result return result
......
...@@ -100,6 +100,8 @@ class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel): ...@@ -100,6 +100,8 @@ class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
else params.weight_scale, else params.weight_scale,
quant_block_shape=tuple(layer.weight_block_size), quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=self.use_deep_gemm_e8m0, use_e8m0=self.use_deep_gemm_e8m0,
is_bmm=getattr(layer, "is_bmm", False),
bmm_batch_size=getattr(layer, "bmm_batch_size", 0),
) )
replace_parameter(layer, params.WEIGHT, dg_weight) replace_parameter(layer, params.WEIGHT, dg_weight)
replace_parameter(layer, scale_attr, dg_weight_scale) replace_parameter(layer, scale_attr, dg_weight_scale)
......
...@@ -1422,6 +1422,20 @@ class MLADims: ...@@ -1422,6 +1422,20 @@ class MLADims:
def get_mla_dims(model_config: ModelConfig) -> MLADims: def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config hf_text_config = model_config.hf_text_config
# Check if this is a DeepseekV4 config (uses unified head_dim + rope_head_dim)
if hasattr(hf_text_config, "compress_ratios"):
# DeepseekV4 style config: unified head_dim with rope_head_dim
head_dim = hf_text_config.head_dim
rope_head_dim = hf_text_config.qk_rope_head_dim
return MLADims(
q_lora_rank=hf_text_config.q_lora_rank,
kv_lora_rank=head_dim,
qk_nope_head_dim=head_dim - rope_head_dim,
qk_rope_head_dim=rope_head_dim,
v_head_dim=head_dim,
)
# DeepseekV2/V3 style config
return MLADims( return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank, kv_lora_rank=hf_text_config.kv_lora_rank,
...@@ -2191,6 +2205,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -2191,6 +2205,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
qk_head_dim: int, qk_head_dim: int,
v_head_dim: int, v_head_dim: int,
kv_b_proj: ColumnParallelLinear, kv_b_proj: ColumnParallelLinear,
# DSV3.2 MLA Specific Arguments
indexer: object | None = None, indexer: object | None = None,
q_pad_num_heads: int | None = None, q_pad_num_heads: int | None = None,
) -> None: ) -> None:
...@@ -2213,6 +2228,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -2213,6 +2228,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.indexer = indexer self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads self.q_pad_num_heads = q_pad_num_heads
self.supports_quant_query_input = True self.supports_quant_query_input = True
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
# Use flashinfer's optimized concat_mla_k kernel when available. # Use flashinfer's optimized concat_mla_k kernel when available.
# The kernel is optimized for DeepSeek V3 dimensions: # The kernel is optimized for DeepSeek V3 dimensions:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, cast
import torch
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
)
from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
_fused_kv_compress_norm_rope_insert_indexer_attn,
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
_fused_kv_compress_norm_rope_insert_sparse_attn,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
MXFP4_BLOCK_SIZE,
)
from vllm.v1.kv_cache_interface import (
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowMLASpec,
)
class CompressorBackend(AttentionBackend):
    """Attention backend descriptor for the compressor state cache.

    It only describes the cache layout and points at the metadata builder;
    no attention computation is defined here.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "CompressorBackend"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the supported kernel block sizes (any multiple of 1)."""
        any_block_size = MultipleOf(1)
        return [any_block_size]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [512, 1024]

    @staticmethod
    def get_builder_cls() -> type["CompressorMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return CompressorMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the cache shape ``(num_blocks, block_size, head_size)``.

        The head dimension is dropped from the shape, so exactly one KV head
        is required. ``cache_dtype_str`` is accepted but not consulted.
        """
        assert num_kv_heads == 1
        shape = (num_blocks, block_size, head_size)
        return shape

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the identity stride order for the cache tensor.

        The order has four entries when the layer dimension is included and
        three otherwise.
        """
        rank = 4 if include_num_layers_dimension else 3
        return tuple(range(rank))
@dataclass
class CompressorMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor
block_size: int
token_to_req_indices: torch.Tensor | None = None # [num_tokens]
class CompressorMetadataBuilder(AttentionMetadataBuilder):
    """Builds :class:`CompressorMetadata` for each scheduling step.

    A persistent int32 buffer sized for ``max_num_batched_tokens`` is
    allocated once on the builder's device; each ``build`` call fills a
    prefix of it with the token -> request mapping.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        spec = self.kv_cache_spec
        assert isinstance(spec, SlidingWindowMLASpec | MLAAttentionSpec)
        self.block_size = cast(SlidingWindowMLASpec | MLAAttentionSpec, spec).block_size

        # Reused across steps so that the metadata always points at the same
        # device storage.
        max_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
        self.token_to_req_indices = torch.zeros(
            max_tokens,
            dtype=torch.int32,
            device=self.device,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CompressorMetadata:
        """Assemble the compressor metadata for the current batch.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not used here.
        """
        starts = common_attn_metadata.query_start_loc_cpu
        tokens_per_req = starts[1:] - starts[:-1]

        # Expand request ids so that entry i holds the request of token i,
        # then pin the host tensor so the device copy can be asynchronous.
        req_ids = torch.arange(common_attn_metadata.num_reqs)
        host_indices = torch.repeat_interleave(req_ids, tokens_per_req).pin_memory()

        device_indices = self.token_to_req_indices[: host_indices.shape[0]]
        device_indices.copy_(host_indices, non_blocking=True)

        # NOTE: clamp_ works in place, so the shared block table tensor in
        # `common_attn_metadata` is modified as well.
        clamped_block_table = common_attn_metadata.block_table_tensor.clamp_(min=0)
        return CompressorMetadata(
            block_table=clamped_block_table,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_size=self.block_size,
            token_to_req_indices=device_indices,
        )
class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
    """Cache layer holding the fp32 partial states of a compressor.

    The instance registers itself in the current vLLM config's
    ``static_forward_context`` under ``prefix`` and reports a
    :class:`SlidingWindowMLASpec` describing its storage.
    """

    def __init__(
        self,
        state_dim: int,
        dtype: torch.dtype,
        compress_ratio: int,
        prefix: str,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.dtype = dtype
        self.prefix = prefix
        # Empty placeholder; the populated cache tensor is read back from
        # this attribute by DeepseekCompressor.forward.
        self.kv_cache = torch.tensor([])

        registry = get_current_vllm_config().compilation_config.static_forward_context
        if prefix in registry:
            raise ValueError(f"Duplicate layer name: {prefix}")
        registry[prefix] = self

        assert self.dtype == torch.float32
        assert compress_ratio in [4, 128]
        # Ratio 4 keeps a window of two compression groups, otherwise one.
        window_groups = 2 if compress_ratio == 4 else 1
        self.sliding_window = window_groups * compress_ratio

        # Block size is constrained by tensor sharing between compressor states
        # and KV blocks. Since compressor states share the same physical tensor
        # as KV blocks, they must use the same page size.
        # The KV block shape [256//4, head_dim] = [64, 584] determines:
        # - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4
        # - C128 compressor block shape [8, 512*2*4] -> block_size = 8
        # TODO(yifan): make block size automatically determined and configurable.
        block_size_by_ratio = {4: 4, 128: 8}
        if compress_ratio not in block_size_by_ratio:
            raise ValueError(f"Invalid compress ratio: {compress_ratio}")
        self.block_size = block_size_by_ratio[compress_ratio]

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the cache as a single-head sliding-window MLA spec."""
        # Only one vector is stored per slot instead of separate K and V.
        spec = SlidingWindowMLASpec(
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.state_dim,
            dtype=self.dtype,
            sliding_window=self.sliding_window,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
        )
        return spec

    def forward(self): ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class associated with this cache layer."""
        return CompressorBackend
class DeepseekCompressor(nn.Module):
    """KV compressor used by DeepSeek V4 attention layers.

    For every token the module projects the hidden state into a ``kv`` and a
    ``score`` vector, stores both (the score with a position-dependent ``ape``
    row added) in its own state cache, and then launches a fused kernel that
    reads the state cache and writes into the KV cache of the layer registered
    under ``k_cache_prefix``.

    The fused kernel and its quantisation constants are chosen in
    ``__init__`` from ``head_dim`` (512 or 128) and ``use_fp4_cache``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        compress_ratio: int,
        hidden_size: int,
        head_dim: int,
        rotate: bool = False,
        prefix: str = "",
        k_cache_prefix: str = "",
        use_fp4_cache: bool = False,
    ):
        """Create the projection, norm, state cache and kernel configuration.

        Args:
            vllm_config: Source of the HF config, scheduler limits and the
                static forward context.
            compress_ratio: Compression ratio; ratio 4 enables the overlap
                mode (``coff == 2``).
            hidden_size: Input feature size of the fused projection.
            head_dim: Head dimension; only 512 and 128 are supported.
            rotate: Stored on the instance; not otherwise used in this class.
            prefix: Layer-name prefix for the sub-layers created here.
            k_cache_prefix: Name under which the target KV cache layer and its
                metadata are looked up at forward time.
            use_fp4_cache: Select the MXFP4 kernel variant (head_dim 128 only).

        Raises:
            ValueError: If ``head_dim`` is neither 512 nor 128.
        """
        super().__init__()
        self.compress_ratio = compress_ratio
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.rotate = rotate
        self.prefix = prefix
        self.k_cache_prefix = k_cache_prefix
        self.use_fp4_cache = use_fp4_cache
        config = vllm_config.model_config.hf_config
        self.rope_head_dim = config.qk_rope_head_dim
        self.nope_head_dim = self.head_dim - self.rope_head_dim
        self.rms_norm_eps = config.rms_norm_eps
        self.device = current_platform.device_type
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.max_model_len = vllm_config.model_config.max_model_len
        # Overlap mode is tied to compress_ratio == 4; coff is 2 in that case
        # and 1 otherwise, and scales the width of kv/score/ape.
        self.overlap = compress_ratio == 4
        self.coff = 1 + self.overlap
        state_dtype = torch.float32
        # Table added to the score; row (position % compress_ratio) is used
        # (see _save_partial_states_kernel).
        self.ape = nn.Parameter(
            torch.empty(
                (compress_ratio, self.coff * self.head_dim),
                dtype=state_dtype,
                device=self.device,
            ),
            requires_grad=False,
        )
        # Single projection producing [kv | score], each coff * head_dim wide.
        self.fused_wkv_wgate = MergedColumnParallelLinear(
            self.hidden_size,
            [self.coff * self.head_dim, self.coff * self.head_dim],
            bias=False,
            return_bias=False,
            quant_config=None,
            disable_tp=True,
            prefix=f"{prefix}.fused_wkv_wgate",
        )
        self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)
        self.state_cache = CompressorStateCache(
            state_dim=2 * self.coff * self.head_dim,  # kv_state + score_state
            dtype=state_dtype,
            compress_ratio=compress_ratio,
            prefix=f"{prefix}.state_cache",
        )
        # Save reference to static_forward_context for forward-time KV cache lookup.
        # get_current_vllm_config() is only available during __init__, not forward.
        self._static_forward_context = (
            vllm_config.compilation_config.static_forward_context
        )
        # Select the fused kernel and its launch constants by head size.
        if self.head_dim == 512:
            assert not use_fp4_cache, (
                "MXFP4 cache is only supported for indexer (head=128)"
            )
            self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
            self._quant_block = 64
            self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
            self._scale_dim = self.nope_head_dim // 64 + 1  # 7 real + 1 pad
            self._num_warps = 4
        elif self.head_dim == 128:
            if use_fp4_cache:
                self._fused_kernel = (
                    _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
                )
                self._quant_block = MXFP4_BLOCK_SIZE
                self._token_stride = self.head_dim // 2
                self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE
            else:
                self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
                self._quant_block = 128
                self._token_stride = self.head_dim
                self._scale_dim = 4  # single float32 scale
            # Shared by both head_dim == 128 variants; forward() reads it
            # unconditionally.
            self._num_warps = 1
        else:
            raise ValueError(
                f"Unsupported head_dim for fused quant+cache: {self.head_dim}"
            )

    def forward(
        self,
        # [num_tokens, hidden_size]
        x: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        rotary_emb,
    ) -> None:
        """Update the compressor state cache and the target KV cache.

        Nothing is returned; all results are written into cache tensors. When
        the forward context carries no per-layer metadata dict (dummy /
        profiling run) the method returns right after the projection GEMM.

        Args:
            x: Hidden states, ``[num_tokens, hidden_size]``.
            positions: Token positions, ``[num_tokens]``.
            rotary_emb: Object exposing ``cos_sin_cache`` for the fused RoPE.
        """
        # Unpacking also enforces that x is two-dimensional.
        num_tokens, _ = x.shape
        # bf16 weights/activations but fp32 output for numerical stability of
        # the downstream compressor math.
        kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight)
        # Each of shape [num_tokens, coff * self.head_dim]
        # input bf16, output are fp32
        kv, score = kv_score.split(
            [self.coff * self.head_dim, self.coff * self.head_dim], dim=-1
        )
        # Get the metadata and handle dummy profiling run.
        attn_metadata = get_forward_context().attn_metadata
        if not isinstance(attn_metadata, dict):
            return
        state_metadata = cast(
            CompressorMetadata, attn_metadata[self.state_cache.prefix]
        )
        token_to_req_indices = state_metadata.token_to_req_indices
        slot_mapping = state_metadata.slot_mapping
        # One kernel program is launched per slot_mapping entry.
        num_actual = slot_mapping.shape[0]
        block_table = state_metadata.block_table
        block_size = state_metadata.block_size
        # [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
        state_cache = self.state_cache.kv_cache
        # kv_state stored in first half, score_state stored in second half
        state_width = state_cache.shape[-1] // 2
        # Store the KV and score (with fused APE addition) in the state.
        # NOTE: PDL is disabled — both this kernel and _fused_kernel below
        # depend on preceding kernel outputs (kv/score from the cublas GEMM;
        # state_cache from this kernel) but neither emits/waits on PDL grid
        # dependency primitives, so launch_pdl=True caused a read-after-write
        # race and non-deterministic output.
        _save_partial_states_kernel[(num_actual,)](
            kv,
            kv.stride(0),
            score,
            score.stride(0),
            self.ape,
            self.ape.stride(0),
            positions,
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            slot_mapping,
            block_size,
            HEAD_SIZE=kv.shape[-1],
            TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            launch_pdl=False,
        )
        # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
        # RoPE requirements (kernel applies forward GPT-J style rotation):
        #   - is_neox_style=False (interleaved pairs, NOT split-half)
        #   - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos,
        #     second half sin (per-pair, length rope_head_dim // 2 each)
        #   - applied to LAST rope_head_dim elements of head_dim
        #   - position used: (positions // compress_ratio) * compress_ratio
        cos_sin_cache = rotary_emb.cos_sin_cache
        # Metadata and cache tensor of the layer that owns the target KV cache.
        k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
        kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache
        self._fused_kernel[(num_actual,)](
            # state cache
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            # metadata
            token_to_req_indices,
            positions,
            slot_mapping,
            block_table,
            block_table.stride(0),
            block_size,
            # RMSNorm
            self.norm.weight,
            self.rms_norm_eps,
            # RoPE
            cos_sin_cache,
            cos_sin_cache.stride(0),
            # KV cache
            kv_cache,
            k_cache_metadata.slot_mapping,
            kv_cache.shape[1],  # paged KV cache block size (tokens per block)
            # constexprs
            HEAD_SIZE=self.head_dim,
            TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            OVERLAP=self.overlap,
            ROPE_HEAD_DIM=self.rope_head_dim,
            FP8_MAX=448.0,
            QUANT_BLOCK=self._quant_block,
            TOKEN_STRIDE=self._token_stride,
            SCALE_DIM=self._scale_dim,
            KV_BLOCK_STRIDE=kv_cache.stride(0),
            num_warps=self._num_warps,
            launch_pdl=False,
        )
@triton.jit
def _save_partial_states_kernel(
    kv_ptr,
    kv_stride,
    score_ptr,
    score_stride,
    ape_ptr,
    ape_stride,
    positions_ptr,
    state_cache_ptr,
    state_cache_stride0,
    state_cache_stride1,
    slot_mapping_ptr,
    block_size,
    HEAD_SIZE: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
    # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
    STATE_WIDTH: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
):
    # Copies one token's kv row and (score + ape) row into the paged state
    # cache. One program handles one token (program id 0 is the token index).
    #
    # kv_stride / score_stride / ape_stride are row strides of the respective
    # inputs; state_cache_stride0 / stride1 are the block and in-block slot
    # strides of the state cache. HEAD_SIZE is the number of valid elements
    # per row and TRITON_BLOCK_SIZE the (masked) vector width used to load it.
    token_idx = tl.program_id(0)
    slot_id = tl.load(slot_mapping_ptr + token_idx)
    # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used
    # by vLLM).  During CUDA graph replay the batch may contain padding
    # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an
    # illegal memory access.
    if slot_id < 0:
        return
    # Translate the flat slot id into (block, position-in-block).
    block_idx = slot_id // block_size
    pos_in_block = slot_id % block_size
    base_ptr = (
        state_cache_ptr
        + block_idx * state_cache_stride0
        + pos_in_block * state_cache_stride1
    )
    # Lanes beyond HEAD_SIZE are masked out for every load and store below.
    block = tl.arange(0, TRITON_BLOCK_SIZE)
    mask = block < HEAD_SIZE
    # kv goes to the first STATE_WIDTH-wide half of the slot.
    kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask)
    tl.store(base_ptr + block, kv, mask=mask)
    # Fused: score += ape[position % compress_ratio]
    position = tl.load(positions_ptr + token_idx)
    ape_row = position % COMPRESS_RATIO
    ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask)
    score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask)
    # The biased score goes to the second half, offset by STATE_WIDTH.
    tl.store(
        base_ptr + STATE_WIDTH + block,
        score + ape,
        mask=mask,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DeepseekV4 MLA Attention Layer
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm.model_executor.layers.linear import (
ReplicatedLinear,
)
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.utils.deep_gemm import fp8_einsum
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.ops.deepseek_v4_ops import (
combine_topk_swa_indices,
compute_global_topk_indices_and_lens,
dequantize_and_gather_k_cache,
fused_indexer_q_rope_quant,
fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
)
if TYPE_CHECKING:
from vllm.v1.attention.backends.mla.sparse_swa import (
DeepseekSparseSWAMetadata,
)
from vllm.config import (
CacheConfig,
VllmConfig,
get_current_vllm_config,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.mla.flashmla_sparse import (
DeepseekV4FlashMLASparseBackend,
FlashMLASparseBackend,
FlashMLASparseMetadata,
)
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV4IndexerBackend,
get_max_prefill_buffer_size,
)
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache
from vllm.v1.attention.ops.flashmla import (
flash_mla_sparse_fwd,
flash_mla_with_kvcache,
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
# Module-level logger named after this module.
logger = init_logger(__name__)
# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
# workspace allocated at _forward_prefill (and the matching profile-time
# reservation in attention_impl's dummy-run branch, where it is the leading
# dimension of the reserved workspace shape).
PREFILL_CHUNK_SIZE = 4
@dataclass
class DeepseekV4MLAModules:
    """Pre-built submodules and buffers consumed by the DeepseekV4 MLA wrapper."""

    # Global config; the wrapper reads the HF config and compilation config.
    vllm_config: VllmConfig
    # Fused projection whose output is split into [q_lora_rank, head_dim].
    fused_wqa_wkv: torch.nn.Module
    # Norm whose weight is applied to the low-rank query.
    q_norm: torch.nn.Module
    # Projection from the normalised low-rank query to per-head queries.
    wq_b: torch.nn.Module
    # Norm whose weight is applied to the kv vector.
    kv_norm: torch.nn.Module
    # First output projection; its FP8 weight and scale feed the grouped einsum.
    wo_a: torch.nn.Module
    # Final output projection applied to the flattened einsum result.
    wo_b: torch.nn.Module
    # Attention sink forwarded to the inner attention layer.
    attn_sink: torch.nn.Module
    # Rotary embedding providing `cos_sin_cache` for the attention path.
    rotary_emb: torch.nn.Module
    # Optional sparse-attention indexer.
    indexer: torch.nn.Module | None
    # Rotary embedding handed to the indexer.
    indexer_rotary_emb: torch.nn.Module
    # Optional buffer forwarded to the inner attention layer.
    topk_indices_buffer: torch.Tensor | None
    # Optional auxiliary CUDA stream used to overlap work.
    aux_stream: torch.cuda.Stream | None = None
# --8<-- [start:multi_head_latent_attention]
@PluggableLayer.register("deepseek_v4_multi_head_latent_attention")
class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
    """Pluggable MLA layer which allows OOT (out-of-tree) backends to add
    custom implementations of the outer MLA layer (including rope & o_proj).

    Note that currently OOT platforms can still use CustomOp.register_oot to
    replace the MLA layer entirely, although we use PluggableLayer to register
    this layer now.

    This class takes positions and hidden_states as input.
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:

    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """

    # --8<-- [end:multi_head_latent_attention]
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        o_lora_rank: int | None,
        mla_modules: DeepseekV4MLAModules,
        window_size: int,
        compress_ratio: int | None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Wire up the projection modules, caches and inner attention layer.

        ``qk_nope_head_dim``, ``qk_rope_head_dim``, ``v_head_dim`` and
        ``o_lora_rank`` are accepted but not read in this constructor: the
        rope dimension and ``o_lora_rank`` are taken from the HF config and
        the nope dimension is derived as ``head_dim - rope_head_dim``.

        Raises:
            ValueError: If ``num_heads`` exceeds 128, or if the layer name is
                already present in the static forward context.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_local_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        # FlashMLA sparse kernel only supports 64 or 128 heads; pad up to the
        # next supported size. Must match DeepseekV4MLAAttention.padded_heads.
        if num_heads <= 64:
            self.padded_heads = 64
        elif num_heads <= 128:
            self.padded_heads = 128
        else:
            raise ValueError(
                f"DeepseekV4 attention does not support {num_heads} heads "
                "(must be <= 128)."
            )
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.window_size = window_size
        # A missing compress ratio is treated as 1 (no compressor is created).
        self.compress_ratio = compress_ratio if compress_ratio is not None else 1
        self.prefix = prefix
        # Extract config from vllm_config
        config = mla_modules.vllm_config.model_config.hf_config
        tp_size = get_tensor_model_parallel_world_size()
        # DeepseekV4-specific attributes (num_heads is already TP-adjusted)
        self.eps = config.rms_norm_eps
        self.rope_head_dim = config.qk_rope_head_dim
        self.nope_head_dim = head_dim - self.rope_head_dim
        self.n_local_groups = config.o_groups // tp_size
        self.o_lora_rank = config.o_lora_rank
        # Store projection modules
        self.fused_wqa_wkv = mla_modules.fused_wqa_wkv
        self.q_norm = mla_modules.q_norm
        self.wq_b = mla_modules.wq_b
        self.kv_norm = mla_modules.kv_norm
        self.wo_a = mla_modules.wo_a
        self._wo_a_act_quant = QuantFP8(
            static=False,
            group_shape=GroupShape(1, 128),
            use_ue8m0=True,
        )
        # Bypass packed-for-deepgemm path — we need FP32 scales (not packed
        # INT32) so fp8_einsum can handle layout transform internally.
        self._wo_a_act_quant.use_deep_gemm_supported = False
        self.wo_b = mla_modules.wo_b
        # Pick fp8_einsum recipe based on GPU arch:
        # SM90:  FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
        # SM100: INT32 packed scales become [g, r, ...]   → sfb_gran_mn=1
        from vllm.platforms import current_platform

        cap = current_platform.get_device_capability()
        assert cap is not None, "DeepseekV4 attention requires a CUDA device"
        self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
        self._tma_aligned_scales = cap.major >= 10
        self.rotary_emb = mla_modules.rotary_emb
        self.indexer_rotary_emb = mla_modules.indexer_rotary_emb
        self.topk_indices_buffer = mla_modules.topk_indices_buffer
        self.indexer = mla_modules.indexer
        # Per-head RMS normalization for Q (no learnable weights)
        self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
        # Bytes stored per cached token; forwarded to the inner attention layer.
        # TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
        head_bytes = (
            self.nope_head_dim  # 448 fp8 NoPE
            + self.rope_head_dim * 2  # 64 bf16 RoPE
            + self.nope_head_dim // 64  # 7B scale factors
            + 1  # 1B pad
        )
        self.aux_stream = mla_modules.aux_stream
        # Event pair handed to maybe_execute_in_parallel in attention_impl.
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
        assert cache_config is not None, "DeepseekV4 attention requires cache_config"
        self.swa_cache_layer = DeepseekV4SWACache(
            head_dim=self.head_dim,
            window_size=self.window_size,
            dtype=torch.uint8,
            prefix=f"{prefix}.swa_cache",
            cache_config=cache_config,
        )
        self.mla_attn = DeepseekV4MLAAttention(
            num_heads=self.n_local_heads,
            head_dim=self.head_dim,
            scale=self.scale,
            qk_nope_head_dim=self.nope_head_dim,
            qk_rope_head_dim=self.rope_head_dim,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            compress_ratio=self.compress_ratio,
            window_size=self.window_size,
            head_bytes=head_bytes,
            swa_cache_layer=self.swa_cache_layer,
            attn_sink=mla_modules.attn_sink,  # already padded with -inf
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            indexer=self.indexer,
            topk_indices_buffer=self.topk_indices_buffer,
        )
        # Register this layer in the compilation config's static forward context
        # This allows the custom op to retrieve the layer during execution
        compilation_config = mla_modules.vllm_config.compilation_config
        # HACK
        self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention"
        if self.layer_name in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {self.layer_name}")
        compilation_config.static_forward_context[self.layer_name] = self
        # Create the compressor for layers with compress_ratio > 1; after
        # creating the DeepseekV4MLAAttention layer to get its cache.
        self.compressor = None
        if self.compress_ratio > 1:
            self.compressor = DeepseekCompressor(
                vllm_config=mla_modules.vllm_config,
                compress_ratio=self.compress_ratio,
                hidden_size=self.hidden_size,
                head_dim=self.head_dim,
                rotate=True,
                prefix=f"{prefix}.compressor",
                k_cache_prefix=self.mla_attn.prefix,
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention plus the output projection for a batch of tokens.

        Args:
            positions: Token positions.
            hidden_states: Input activations; dimension 0 is the token count.
            llama_4_scaling: Accepted but not used in this method.

        Returns:
            The result of ``wo_b`` applied to the flattened grouped einsum
            output.
        """
        qr_kv, _ = self.fused_wqa_wkv(hidden_states)
        qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
        # Pre-allocate attention output with FlashMLA-padded head count.
        # The op writes into `o_padded`; we slice to n_local_heads after.
        num_tokens = hidden_states.shape[0]
        o_padded = torch.empty(
            (num_tokens, self.padded_heads, self.head_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # Attention (inside custom op for torch.compile boundary)
        torch.ops.vllm.deepseek_v4_attention(
            hidden_states,
            qr,
            kv,
            positions,
            o_padded,
            self.layer_name,
        )
        o = o_padded[:, : self.n_local_heads, :]
        # O projection: inverse RoPE + FP8 quant + einsum + wo_b
        o_fp8, o_scale = fused_inv_rope_fp8_quant(
            o,
            positions,
            self.rotary_emb.cos_sin_cache,
            n_groups=self.n_local_groups,
            heads_per_group=self.n_local_heads // self.n_local_groups,
            nope_dim=self.nope_head_dim,
            rope_dim=self.rope_head_dim,
            tma_aligned_scales=self._tma_aligned_scales,
        )
        wo_a_fp8 = self.wo_a.weight
        wo_a_scale = self.wo_a.weight_scale_inv
        # Output buffer for the grouped einsum; the custom op fills it in place.
        z = torch.empty(
            (num_tokens, self.n_local_groups, self.o_lora_rank),
            device=o.device,
            dtype=torch.bfloat16,
        )
        torch.ops.vllm.deepseek_v4_fp8_einsum(
            o_fp8,
            o_scale,
            wo_a_fp8,
            wo_a_scale,
            z,
            "bhr,hdr->bhd",
            list(self._einsum_recipe),
        )
        return self.wo_b(z.flatten(1))

    def attention_impl(
        self,
        hidden_states: torch.Tensor,
        qr: torch.Tensor,
        kv: torch.Tensor,
        positions: torch.Tensor,
        out: torch.Tensor,  # [num_tokens, padded_heads, head_dim], written in place
    ) -> None:
        """Body of the ``deepseek_v4_attention`` custom op.

        Normalises ``qr``/``kv``, projects the query, runs the KV insert
        together with the indexer and/or compressor when present, and finally
        calls the inner attention layer, which writes into ``out``. On a dummy
        run (no per-layer metadata dict) it reserves the prefill workspace,
        zeroes ``out`` and returns.
        """
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        qr, kv = fused_q_kv_rmsnorm(
            qr,
            kv,
            self.q_norm.weight.data,
            self.kv_norm.weight.data,
            self.eps,
        )
        q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
        # Overlap kv_insert with whichever of indexer/compressor is present.
        # Indexer implies compressor; when both exist, compressor rides on the
        # aux stream alongside kv_insert so the heavy indexer owns default.
        if self.indexer is not None:
            indexer = self.indexer
            # Local ref so the closure keeps a non-None type for mypy.
            assert self.compressor is not None
            compressor = self.compressor

            def kv_insert_and_compress() -> None:
                self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
                compressor(hidden_states, positions, self.rotary_emb)

            maybe_execute_in_parallel(
                lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb),
                kv_insert_and_compress,
                self.ln_events[0],
                self.ln_events[1],
                self.aux_stream,
            )
        elif self.compressor is not None:
            # Compressor on default, kv_insert on aux.
            compressor = self.compressor
            maybe_execute_in_parallel(
                lambda: compressor(hidden_states, positions, self.rotary_emb),
                lambda: self._fused_qnorm_rope_kv_insert(
                    q, kv, positions, attn_metadata
                ),
                self.ln_events[0],
                self.ln_events[1],
                self.aux_stream,
            )
        else:
            # SWA-only layer: no compressor, no overlap.
            self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
        # Handle dummy run (no metadata).
        if not isinstance(attn_metadata, dict):
            # Reserve _forward_prefill's bf16-gather workspace; the dummy
            # run returns before mla_attn runs, so without this the shared
            # workspace locks below the real prefill size.
            sub = self.mla_attn
            swa_only = sub.compress_ratio <= 1
            # N: ceil(max_model_len / compress_ratio), or 0 for SWA-only layers.
            N = (
                0
                if swa_only
                else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio
            )
            M = N + sub.window_size + sub.max_num_batched_tokens
            current_workspace_manager().get_simultaneous(
                ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
            )
            out.zero_()
            return
        # Pad q to FlashMLA-required head count (64 or 128)
        if self.n_local_heads < self.padded_heads:
            pad_size = self.padded_heads - self.n_local_heads
            # Zero-pad along the head dimension (second-to-last) only.
            q = F.pad(q, (0, 0, 0, pad_size), value=0.0)
        # MLA attention writes into the pre-allocated `out` buffer
        # ([num_tokens, padded_heads, head_dim]).
        self.mla_attn(q, kv, positions, output=out)

    def _fused_qnorm_rope_kv_insert(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: (
            dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None
        ),
    ) -> None:
        """Launch the fused Q-norm/RoPE and KV RoPE/quant/cache-insert op.

        Does nothing when ``attn_metadata`` is not a dict (dummy run).
        """
        if not isinstance(attn_metadata, dict):
            return
        swa_metadata = cast(
            "DeepseekSparseSWAMetadata | None",
            attn_metadata.get(self.swa_cache_layer.prefix),
        )
        assert swa_metadata is not None
        swa_kv_cache = self.swa_cache_layer.kv_cache
        # Flatten everything after the block dimension into one axis.
        swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1)
        # Horizontally fused:
        #   Q side:  q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE
        #   KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert
        # kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
        torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
            q,
            kv,
            swa_kv_cache_2d,
            swa_metadata.slot_mapping,
            positions.to(torch.int64),
            self.rotary_emb.cos_sin_cache,
            self.eps,
            swa_metadata.block_size,
        )
def deepseek_v4_attention(
    hidden_states: torch.Tensor,
    qr: torch.Tensor,
    kv: torch.Tensor,
    positions: torch.Tensor,
    out: torch.Tensor,
    layer_name: str,
) -> None:
    """Dispatch to the ``attention_impl`` of the layer named ``layer_name``.

    The layer object is looked up in the current forward context's
    ``no_compile_layers`` mapping. Nothing is returned; ``out`` is passed
    through to the layer as its output buffer.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer.attention_impl(hidden_states, qr, kv, positions, out)
def deepseek_v4_attention_fake(
    hidden_states: torch.Tensor,
    qr: torch.Tensor,
    kv: torch.Tensor,
    positions: torch.Tensor,
    out: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op counterpart of ``deepseek_v4_attention`` with the same signature.

    Touches none of its arguments and returns ``None``.
    """
    return
# Register `deepseek_v4_attention` as a custom op named
# "deepseek_v4_attention". `out` is declared as the only argument the op
# mutates, and `deepseek_v4_attention_fake` is supplied as its fake
# implementation.
direct_register_custom_op(
    op_name="deepseek_v4_attention",
    op_func=deepseek_v4_attention,
    mutates_args=["out"],
    fake_impl=deepseek_v4_attention_fake,
)
def deepseek_v4_fp8_einsum(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    out: torch.Tensor,
    equation: str,
    recipe: list[int],
) -> None:
    """Forward to ``fp8_einsum``, pairing each operand with its scale.

    Args:
        a: First operand.
        a_scale: Scale tensor paired with ``a``.
        b: Second operand.
        b_scale: Scale tensor paired with ``b``.
        out: Output tensor handed to ``fp8_einsum``.
        equation: Einsum equation string.
        recipe: Recipe given as a list; converted to a tuple for the call.
    """
    lhs = (a, a_scale)
    rhs = (b, b_scale)
    fp8_einsum(equation, lhs, rhs, out, recipe=tuple(recipe))
def deepseek_v4_fp8_einsum_fake(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    out: torch.Tensor,
    equation: str,
    recipe: list[int],
) -> None:
    """No-op counterpart of ``deepseek_v4_fp8_einsum`` with the same signature.

    Touches none of its arguments and returns ``None``.
    """
    return
# Register `deepseek_v4_fp8_einsum` as a custom op named
# "deepseek_v4_fp8_einsum". `out` is declared as the only argument the op
# mutates, and `deepseek_v4_fp8_einsum_fake` is supplied as its fake
# implementation.
direct_register_custom_op(
    op_name="deepseek_v4_fp8_einsum",
    op_func=deepseek_v4_fp8_einsum,
    mutates_args=["out"],
    fake_impl=deepseek_v4_fp8_einsum_fake,
)
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
# FlashMLA FP8 sparse only supports 64 or 128 heads
SUPPORTED_HEAD_COUNTS = (64, 128)
def __init__(
    self,
    num_heads: int,
    head_dim: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    compress_ratio: int,
    window_size: int,
    head_bytes: int,
    swa_cache_layer: DeepseekV4SWACache,
    attn_sink: torch.Tensor,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    # Sparse MLA Args
    indexer: object | None = None,
    topk_indices_buffer: torch.Tensor | None = None,
    aux_stream: torch.cuda.Stream | None = None,
    **extra_impl_args,
) -> None:
    """Set up the layer's attributes, KV-cache dtype and context registration.

    Args:
        num_heads: Number of local attention heads. Must be at most 128;
            values other than 64/128 are padded up to 64 or 128.
        head_dim: Head size reported in the KV cache spec.
        scale: Softmax scale passed to the attention kernels.
        qk_nope_head_dim: Stored as ``nope_head_dim``.
        qk_rope_head_dim: Stored as ``rope_head_dim``.
        q_lora_rank: Stored as-is.
        kv_lora_rank: Stored as-is.
        compress_ratio: KV compression ratio; values ``<= 1`` mark the
            layer as SWA-only elsewhere in this class.
        window_size: Sliding-window size.
        head_bytes: Stored as-is.
        swa_cache_layer: Layer that owns the SWA KV cache. Required.
        attn_sink: Attention sink tensor. Required.
        cache_config: Optional cache config. When given and its dtype is an
            fp8 variant other than ``"fp8_ds_mla"``, its ``cache_dtype`` is
            rewritten to ``"fp8_ds_mla"``.
        quant_config: Accepted for interface compatibility; not used here.
        prefix: Layer name. When non-empty the layer is registered in the
            compilation config's ``static_forward_context``.
        indexer: Optional indexer object, stored as-is.
        topk_indices_buffer: Optional buffer of top-k indices, stored as-is.
        aux_stream: Optional auxiliary CUDA stream, stored as-is.
        **extra_impl_args: Accepted and ignored by this constructor.

    Raises:
        ValueError: If ``num_heads`` exceeds 128, or if ``prefix`` is already
            registered in ``static_forward_context``.
        AssertionError: If ``attn_sink`` or ``swa_cache_layer`` is ``None``,
            the KV-cache dtype is not an fp8 variant, or the attention
            backend is not a ``FlashMLASparseBackend`` subclass.
    """
    super().__init__()
    self.num_heads = num_heads
    self.num_kv_heads = 1
    self.head_dim = head_dim
    self.scale = scale
    self.window_size = window_size
    self.head_bytes = head_bytes
    self.compress_ratio = compress_ratio
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.nope_head_dim = qk_nope_head_dim
    self.rope_head_dim = qk_rope_head_dim
    self.indexer = indexer
    self.topk_indices_buffer = topk_indices_buffer
    self.prefix = prefix  # Alias for compatibility with compressor
    self.aux_stream = aux_stream
    self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    # Determine padded head count for FlashMLA: exact matches are kept,
    # anything smaller is rounded up to the next supported count.
    if num_heads in self.SUPPORTED_HEAD_COUNTS:
        self.padded_heads = num_heads
    elif num_heads < 64:
        self.padded_heads = 64
    elif num_heads < 128:
        self.padded_heads = 128
    else:
        raise ValueError(
            f"DeepseekV4MLAAttention does not support {num_heads} heads. "
            f"Supported: <= 128 (will be padded to 64 or 128)"
        )

    # Store attention sink
    assert attn_sink is not None
    self.attn_sink: torch.Tensor = attn_sink

    # Store SWA cache
    assert swa_cache_layer is not None
    self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer

    # Get vllm config for cache setup
    vllm_config = get_current_vllm_config()
    self.max_num_batched_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens
    )
    self.max_model_len = vllm_config.model_config.max_model_len

    # DeepseekV4 only supports fp8 kv-cache format for now
    kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
    assert kv_cache_dtype.startswith("fp8"), (
        f"DeepseekV4 only supports fp8 kv-cache format for now, "
        f"got {kv_cache_dtype}"
    )
    assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
        "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
    )

    # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
    # Automatically convert fp8 kv-cache format to "fp8_ds_mla"
    if (
        issubclass(self.get_attn_backend(), FlashMLASparseBackend)
        and kv_cache_dtype.startswith("fp8")
        and kv_cache_dtype != "fp8_ds_mla"
    ):
        # cache_config is optional: with the default (None) the dtype falls
        # back to "fp8" above and reaches this branch, so only write the
        # converted dtype back when a config object was actually supplied.
        if cache_config is not None:
            cache_config.cache_dtype = "fp8_ds_mla"
        kv_cache_dtype = "fp8_ds_mla"
        logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
    self.kv_cache_dtype = kv_cache_dtype

    # Register with compilation context for metadata lookup
    compilation_config = vllm_config.compilation_config
    if prefix:
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    # Placeholder; replaced elsewhere once the real cache is allocated.
    self.kv_cache = torch.tensor([])
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the attention backend class used by this layer."""
    backend_cls = DeepseekV4FlashMLASparseBackend
    return backend_cls
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
    """Describe this layer's own KV cache, if it has one.

    Args:
        vllm_config: Source of the cache block size.

    Returns:
        An ``MLAAttentionSpec`` when ``compress_ratio > 1``; ``None`` for
        ``compress_ratio <= 1``.
    """
    if self.compress_ratio > 1:
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=torch.uint8,
            compress_ratio=self.compress_ratio,
            cache_dtype_str=self.kv_cache_dtype,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
            model_version="deepseek_v4",
        )
    # SWA part. Allocated separately as DeepseekV4SWACache.
    return None
def forward(
    self,
    q: torch.Tensor,
    kv: torch.Tensor,
    positions: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Run attention for a mixed decode/prefill batch, writing into ``output``.

    The first ``num_decode_tokens`` rows of ``q``/``output`` go to
    ``_forward_decode`` and the remaining rows to ``_forward_prefill``.
    Prefill is issued before decode.

    Args:
        q: Query tensor; ``output`` must match its shape and dtype.
        kv: Not read by this method.
        positions: Token positions; the prefill slice is forwarded.
        output: Pre-allocated result buffer, written in place.
    """
    assert output.shape == q.shape, (
        f"output buffer shape {output.shape} must match q shape {q.shape}"
    )
    assert output.dtype == q.dtype, (
        f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
    )

    # Get SWA and indexer metadata from forward context
    attn_metadata = get_forward_context().attn_metadata
    assert isinstance(attn_metadata, dict)
    flashmla_metadata = cast(
        FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
    )
    swa_metadata = cast(
        "DeepseekSparseSWAMetadata | None",
        attn_metadata.get(self.swa_cache_layer.prefix),
    )
    assert swa_metadata is not None

    # SWA-only layers (compress_ratio <= 1) don't have their own KV cache
    # allocation, so self.kv_cache may be empty after profiling cleanup.
    is_swa_only = self.compress_ratio <= 1
    own_cache = None if is_swa_only else self.kv_cache
    window_cache = self.swa_cache_layer.kv_cache

    # Split prefill and decode
    decode_reqs = swa_metadata.num_decodes
    prefill_reqs = swa_metadata.num_prefills
    decode_tokens = swa_metadata.num_decode_tokens

    if prefill_reqs > 0:
        self._forward_prefill(
            q=q[decode_tokens:],
            positions=positions[decode_tokens:],
            compressed_k_cache=own_cache,
            swa_k_cache=window_cache,
            output=output[decode_tokens:],
            attn_metadata=flashmla_metadata,
            swa_metadata=swa_metadata,
        )

    if decode_reqs > 0:
        self._forward_decode(
            q=q[:decode_tokens],
            kv_cache=own_cache,
            swa_metadata=swa_metadata,
            attn_metadata=flashmla_metadata,
            swa_only=is_swa_only,
            output=output[:decode_tokens],
        )
def _forward_decode(
    self,
    q: torch.Tensor,
    kv_cache: torch.Tensor | None,  # Only used when compress_ratio > 1
    swa_metadata: "DeepseekSparseSWAMetadata",
    attn_metadata: FlashMLASparseMetadata | None,
    swa_only: bool,
    output: torch.Tensor,
) -> None:
    """Run decode attention with ``flash_mla_with_kvcache``.

    The SWA cache is always passed as the primary ``k_cache``. For layers
    that are not SWA-only, this layer's own cache is passed as the extra
    cache together with global top-k indices and lengths.

    Args:
        q: Decode queries, already padded to ``self.padded_heads``.
        kv_cache: This layer's own cache, or ``None``.
        swa_metadata: SWA metadata (decode counts, indices, lengths, tile
            scheduler entries).
        attn_metadata: FlashMLA sparse metadata; must not be ``None`` unless
            ``swa_only`` is true.
        swa_only: Whether the layer has no compressed cache.
        output: Buffer the kernel writes into (passed as ``out``).

    Raises:
        ValueError: If ``self.compress_ratio`` is greater than 1 and is
            neither 4 nor 128.
        AssertionError: If required metadata is missing.
    """
    num_decodes = swa_metadata.num_decodes
    num_decode_tokens = swa_metadata.num_decode_tokens
    topk_indices = None
    topk_lens = None
    if not swa_only:
        assert attn_metadata is not None
        assert swa_metadata.is_valid_token is not None
        # Block size expressed in compressed entries.
        block_size = attn_metadata.block_size // self.compress_ratio
        is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
        if self.compress_ratio == 4:
            # C4A: local indices differ per layer (filled by Indexer).
            assert self.topk_indices_buffer is not None
            global_indices, topk_lens = compute_global_topk_indices_and_lens(
                self.topk_indices_buffer[:num_decode_tokens],
                swa_metadata.token_to_req_indices,
                attn_metadata.block_table[:num_decodes],
                block_size,
                is_valid,
            )
            topk_indices = global_indices.view(num_decode_tokens, 1, -1)
        else:
            # C128A: pre-computed during metadata build.
            # NOTE(review): every ratio other than 4 lands here; unsupported
            # ratios are only rejected further down, at tile-metadata
            # selection.
            topk_indices = attn_metadata.c128a_global_decode_topk_indices
            topk_lens = attn_metadata.c128a_decode_topk_lens
    swa_indices = swa_metadata.decode_swa_indices
    swa_lens = swa_metadata.decode_swa_lens
    # We treat queries in the same seq as different queries
    # and later we only attend by generated indices.
    # q arrives pre-padded to self.padded_heads by the outer wrapper.
    q = q.unsqueeze(1)
    # Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes)
    # Use unsqueeze to preserve strides (handles padded blocks correctly)
    swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2)
    # Reshape KV cache to (num_blocks, block_size, 1, head_bytes)
    if kv_cache is not None:
        kv_cache = kv_cache.unsqueeze(-2)
    # One FlashMLASchedMeta per layer type, shared across all same-type
    # layers within this decode step. The first forward call per type
    # triggers the in-kernel planner (allocating tile_scheduler_metadata
    # and num_splits via PyTorch's graph-aware allocator so CUDA graph
    # capture reuses the same addresses on replay); subsequent same-type
    # layers see have_initialized=True and skip the planner.
    if self.compress_ratio <= 1:
        tile_metadata = swa_metadata.tile_sched_swaonly
    elif self.compress_ratio == 4:
        tile_metadata = swa_metadata.tile_sched_c4a
    elif self.compress_ratio == 128:
        tile_metadata = swa_metadata.tile_sched_c128a
    else:
        raise ValueError(
            f"Unsupported compress_ratio={self.compress_ratio}; "
            "expected 1, 4, or 128."
        )
    assert tile_metadata is not None, (
        "swa_metadata missing tile_sched entry for "
        f"compress_ratio={self.compress_ratio}; "
        "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not "
        "allocate one for this layer type."
    )
    # The returned values are not used; the result is written through `out`.
    out, _ = flash_mla_with_kvcache(
        q=q,
        k_cache=swa_cache,
        block_table=None,
        head_dim_v=512,
        tile_scheduler_metadata=tile_metadata,
        cache_seqlens=None,
        is_fp8_kvcache=True,
        indices=swa_indices,
        topk_length=swa_lens,
        softmax_scale=self.scale,
        attn_sink=self.attn_sink,
        extra_k_cache=kv_cache if not swa_only else None,
        extra_indices_in_kvcache=topk_indices,
        extra_topk_length=topk_lens,
        out=output.unsqueeze(1),
    )
def _forward_prefill(
    self,
    q: torch.Tensor,
    positions: torch.Tensor,
    compressed_k_cache: torch.Tensor | None,  # Only used when compress_ratio > 1
    swa_k_cache: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata | None,
    swa_metadata: "DeepseekSparseSWAMetadata",
) -> None:
    """Run prefill attention in chunks of ``PREFILL_CHUNK_SIZE`` requests.

    For each chunk, the compressed KV (if any) and the SWA KV are
    dequantized and gathered into a shared bf16 workspace. Top-k and SWA
    indices are combined, and ``flash_mla_sparse_fwd`` writes the result
    into the matching rows of ``output``.

    Args:
        q: Prefill queries (decode rows already sliced off by the caller).
        positions: Not read by this method.
        compressed_k_cache: This layer's own cache; gathered only when
            ``attn_metadata`` is provided.
        swa_k_cache: The SWA KV cache.
        output: Prefill slice of the output buffer, written in place.
        attn_metadata: FlashMLA sparse metadata; ``None`` selects the
            SWA-only path.
        swa_metadata: SWA metadata carrying prefill sequence/gather lengths,
            query start locations and block tables.

    Raises:
        AssertionError: If required metadata or ``topk_indices_buffer`` is
            missing.
    """
    # NOTE(review): SWA-only is inferred here from missing metadata, whereas
    # forward() derives it from `compress_ratio <= 1` - confirm the two
    # always agree.
    swa_only = attn_metadata is None
    num_prefills = swa_metadata.num_prefills
    num_prefill_tokens = swa_metadata.num_prefill_tokens
    num_decodes = swa_metadata.num_decodes
    num_decode_tokens = swa_metadata.num_decode_tokens

    # Use pre-computed prefill metadata.
    seq_lens = swa_metadata.prefill_seq_lens
    gather_lens = swa_metadata.prefill_gather_lens
    assert seq_lens is not None
    assert gather_lens is not None

    # Derive prefill-local token offsets from the full query_start_loc_cpu.
    query_start_loc_cpu = swa_metadata.query_start_loc_cpu
    query_start_loc = swa_metadata.query_start_loc
    assert query_start_loc_cpu is not None
    assert query_start_loc is not None
    prefill_token_base = query_start_loc_cpu[num_decodes]

    # Loop-invariant inputs of the compressed gather; set only when needed.
    compressed_block_table = None
    compressed_block_size = 0
    if not swa_only:
        assert attn_metadata is not None
        if self.compress_ratio == 4:
            assert self.topk_indices_buffer is not None
            topk_indices = self.topk_indices_buffer[num_decode_tokens:]
            topk_indices = topk_indices[:num_prefill_tokens]
        else:
            # C128A: pre-computed during metadata build.
            topk_indices = attn_metadata.c128a_prefill_topk_indices
        top_k = topk_indices.shape[-1]
        # Compressed region must fit the full compressed pool (seq_len //
        # compress_ratio), not just top_k. top_k bounds how many indices
        # the indexer selects, not the pool size it indexes into.
        N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
        compressed_block_table = attn_metadata.block_table[num_decodes:]
        compressed_block_size = attn_metadata.block_size // self.compress_ratio
    else:
        # NOTE(woosuk): topk_indices will not be used for SWA-only layers.
        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[num_decode_tokens:]
        top_k = 0
        N = 0

    # Per-request workspace rows: compressed region, then SWA window plus
    # room for the batched tokens.
    M = N + self.window_size + self.max_num_batched_tokens
    num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE
    workspace_manager = current_workspace_manager()
    kv = workspace_manager.get_simultaneous(
        ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
    )[0]

    # Hoisted out of the chunk loop: none of these depend on the chunk.
    swa_block_table = swa_metadata.block_table[num_decodes:]
    kv_flat = kv.view(-1, 1, q.shape[-1])

    for chunk_idx in range(num_chunks):
        chunk_start = chunk_idx * PREFILL_CHUNK_SIZE
        chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills)
        chunk_size = chunk_end - chunk_start

        if not swa_only:
            # Gather compressed KV
            assert compressed_block_table is not None
            dequantize_and_gather_k_cache(
                kv[:chunk_size],
                compressed_k_cache,
                seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
                gather_lens=None,
                block_table=compressed_block_table[chunk_start:chunk_end],
                block_size=compressed_block_size,
                offset=0,
            )

        # Gather SWA KV
        dequantize_and_gather_k_cache(
            kv[:chunk_size],
            swa_k_cache,
            seq_lens=seq_lens[chunk_start:chunk_end],
            gather_lens=gather_lens[chunk_start:chunk_end],
            block_table=swa_block_table[chunk_start:chunk_end],
            block_size=swa_metadata.block_size,
            offset=N,
        )

        # Combine the topk indices and SWA indices for gathered KV cache
        query_start = (
            query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base
        )
        query_end = (
            query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base
        )
        combined_indices, combined_lens = combine_topk_swa_indices(
            topk_indices[query_start:query_end],
            query_start_loc[
                num_decodes + chunk_start : num_decodes + chunk_end + 1
            ],
            seq_lens[chunk_start:chunk_end],
            gather_lens[chunk_start:chunk_end],
            self.window_size,
            self.compress_ratio,
            top_k,
            M,
            N,
        )

        # The kernel writes straight into `output`; its return values are
        # not needed.
        flash_mla_sparse_fwd(
            q=q[query_start:query_end],
            kv=kv_flat,
            indices=combined_indices.unsqueeze(1),
            sm_scale=self.scale,
            attn_sink=self.attn_sink,
            topk_length=combined_lens,
            out=output[query_start:query_end],
        )
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
    """Attention-layer stub that owns the indexer's KV cache.

    On construction it registers itself under ``prefix`` in the compilation
    config's ``static_forward_context`` and reports its cache layout through
    ``get_kv_cache_spec``. ``forward`` does nothing.
    """

    def __init__(
        self,
        head_dim: int,
        dtype: torch.dtype,
        prefix: str,
        cache_config: CacheConfig,
        compress_ratio: int = 1,
    ):
        """Store the cache parameters and register the layer.

        Raises:
            ValueError: If ``prefix`` is already registered.
        """
        super().__init__()
        # Placeholder tensor; replaced elsewhere once the cache is allocated.
        self.kv_cache = torch.tensor([])
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype
        self.compress_ratio = compress_ratio

        registry = get_current_vllm_config().compilation_config
        if prefix in registry.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        registry.static_forward_context[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the ``MLAAttentionSpec`` describing this cache.

        The block size comes from the stored ``cache_config``;
        ``vllm_config`` is not read.
        """
        # head_dim already carries the fp8 scale padding
        # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same
        # cache layout.
        # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack
        # with the indexer's compressor state cache. V3.2 keeps the legacy
        # layout.
        spec = MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            compress_ratio=self.compress_ratio,
            alignment=576,
        )
        return spec

    def forward(self):
        """Do nothing; this layer only holds the cache."""

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class used for the indexer cache."""
        return DeepseekV4IndexerBackend
class DeepseekV4Indexer(nn.Module):
    """Indexer module feeding ``SparseAttnIndexer``.

    It projects the low-rank query, compresses hidden states into index keys
    through ``DeepseekCompressor``, applies fused RoPE + quantization to the
    queries and hands everything to ``SparseAttnIndexer``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        compress_ratio: int = 1,
        prefix: str = "",
    ):
        """Build the projections, cache, compressor and indexer op.

        Args:
            vllm_config: Global config; supplies ``max_model_len`` and the
                ``use_fp4_indexer_cache`` attention option.
            config: Model config providing ``index_topk``, ``index_n_heads``,
                ``index_head_dim`` and ``qk_rope_head_dim``.
            hidden_size: Input width of ``weights_proj`` and the compressor.
            q_lora_rank: Input width of ``wq_b``.
            quant_config: Quantization config applied to ``wq_b`` only.
            cache_config: Cache config for the key cache. Required.
            topk_indices_buffer: Buffer passed on to the indexer op.
            compress_ratio: Compression ratio for cache and compressor.
            prefix: Name prefix for the sub-layers.

        Raises:
            AssertionError: If ``cache_config`` is ``None``.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.quant_config = quant_config
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536
        self.compress_ratio = compress_ratio
        self.use_fp4_kv = self.vllm_config.attention_config.use_fp4_indexer_cache
        logger.info_once(
            "Using %s indexer cache for Lightning Indexer.",
            "MXFP4" if self.use_fp4_kv else "FP8",
        )

        # no tensor parallel, just replicated
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        self.weights_proj = ReplicatedLinear(
            hidden_size,
            self.n_head,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.weights_proj",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        self.softmax_scale = self.head_dim**-0.5
        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer
        # Both length limits are expressed in compressed entries.
        self.max_model_len = (
            vllm_config.model_config.max_model_len // self.compress_ratio
        )
        self.prefix = prefix
        self.max_total_seq_len = (
            get_max_prefill_buffer_size(vllm_config) // self.compress_ratio
        )

        assert cache_config is not None, "Deepseek V4 indexer requires cache_config"
        # NOTE(yifan): FP8 indexer cache uses the same layout as V3.2:
        # head_dim bytes = 128 fp8 + 4 fp32 scale = 132.
        # For FP4 indexer cache, we still allocate the same amount of memory
        # as FP8, but only use the first half of the memory.
        k_cache_head_dim = self.head_dim + self.head_dim // self.quant_block_size * 4
        self.k_cache = DeepseekV4IndexerCache(
            head_dim=k_cache_head_dim,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
            compress_ratio=self.compress_ratio,
        )
        self.compressor = DeepseekCompressor(
            vllm_config=vllm_config,
            compress_ratio=self.compress_ratio,
            hidden_size=hidden_size,
            head_dim=self.head_dim,
            rotate=True,
            prefix=f"{prefix}.compressor",
            k_cache_prefix=self.k_cache.prefix,
            use_fp4_cache=self.use_fp4_kv,
        )
        self.indexer_op = SparseAttnIndexer(
            self.k_cache,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
            skip_k_cache_insert=True,
            use_fp4_cache=self.use_fp4_kv,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        qr: torch.Tensor,
        positions: torch.Tensor,
        rotary_emb: nn.Module,
    ) -> torch.Tensor:
        """Compute the indexer output for the given tokens.

        Args:
            hidden_states: Input to the compressor and ``weights_proj``.
            qr: Low-rank query input to ``wq_b``.
            positions: Token positions for RoPE and the compressor.
            rotary_emb: Rotary module; its ``cos_sin_cache`` is used.

        Returns:
            Whatever ``self.indexer_op`` returns.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        k = self.compressor(hidden_states, positions, rotary_emb)
        weights, _ = self.weights_proj(hidden_states)
        q_quant, weights = fused_indexer_q_rope_quant(
            positions,
            q,
            rotary_emb.cos_sin_cache,
            weights,
            self.softmax_scale,
            self.n_head**-0.5,
            use_fp4=self.use_fp4_kv,
        )
        return self.indexer_op(hidden_states, q_quant, k, weights)
...@@ -117,8 +117,10 @@ class RoutingMethodType(IntEnum): ...@@ -117,8 +117,10 @@ class RoutingMethodType(IntEnum):
Custom = (6,) Custom = (6,)
# Simulated # Simulated
Simulated = (7,) Simulated = (7,)
# Deepseek V4 -> sqrtsoftplus + Bias + Normalize
DeepseekV4 = (8,)
# Unspecified # Unspecified
Unspecified = 8.0 Unspecified = 9.0
def get_routing_method_type( def get_routing_method_type(
...@@ -128,6 +130,14 @@ def get_routing_method_type( ...@@ -128,6 +130,14 @@ def get_routing_method_type(
num_expert_group: int | None, num_expert_group: int | None,
has_e_score_bias: bool, has_e_score_bias: bool,
) -> RoutingMethodType: ) -> RoutingMethodType:
if scoring_func == "sqrtsoftplus":
# DeepSeek V4 uses sqrtsoftplus routing with optional routing bias
# and top-k renormalization.
if renormalize:
return RoutingMethodType.DeepseekV4
else:
return RoutingMethodType.Unspecified
if has_e_score_bias: if has_e_score_bias:
if (num_expert_group or 0) > 0 and scoring_func == "sigmoid": if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
return RoutingMethodType.DeepSeekV3 return RoutingMethodType.DeepSeekV3
...@@ -230,6 +240,13 @@ class FusedMoEQuantConfig: ...@@ -230,6 +240,13 @@ class FusedMoEQuantConfig:
_w2: FusedMoEQuantDesc _w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True is_nvfp4_scale_swizzled: bool = True
# MXFP4-specific TRTLLM parameters for SwiGLU activation clamping.
# These correspond to gemm1_alpha, gemm1_beta, gemm1_clamp_limit
# in TrtLlmMxfp4ExpertsBase.
gemm1_alpha: float | None = None
gemm1_beta: float | None = None
gemm1_clamp_limit: float | None = None
def __post_init__(self): def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, ( assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization" "illegal quantization"
...@@ -477,6 +494,9 @@ class FusedMoEQuantConfig: ...@@ -477,6 +494,9 @@ class FusedMoEQuantConfig:
w2_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None, weight_dtype: torch.dtype | str | None = None,
is_nvfp4_scale_swizzled: bool = True, is_nvfp4_scale_swizzled: bool = True,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> "FusedMoEQuantConfig": ) -> "FusedMoEQuantConfig":
""" """
General builder function for a FusedMoEQuantConfig. General builder function for a FusedMoEQuantConfig.
...@@ -507,6 +527,9 @@ class FusedMoEQuantConfig: ...@@ -507,6 +527,9 @@ class FusedMoEQuantConfig:
- w1_zp: Optional w1 zero points for int4/int8 quantization. - w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization. - w2_zp: Optional w2 zero points for int4/int8 quantization.
- is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling. - is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling.
- gemm1_alpha: Optional MXFP4 TRTLLM SwiGLU alpha parameter.
- gemm1_beta: Optional MXFP4 TRTLLM SwiGLU beta parameter.
- gemm1_clamp_limit: Optional MXFP4 TRTLLM SwiGLU clamp limit.
""" """
assert not isinstance(quant_dtype, str) or quant_dtype in { assert not isinstance(quant_dtype, str) or quant_dtype in {
"nvfp4", "nvfp4",
...@@ -540,6 +563,9 @@ class FusedMoEQuantConfig: ...@@ -540,6 +563,9 @@ class FusedMoEQuantConfig:
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
), ),
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
) )
assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant assert quant_config.per_out_ch_quant == per_out_ch_quant
...@@ -650,6 +676,9 @@ def mxfp4_w4a16_moe_quant_config( ...@@ -650,6 +676,9 @@ def mxfp4_w4a16_moe_quant_config(
w2_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for unquantized activations and mxfp4 weights. Construct a quant config for unquantized activations and mxfp4 weights.
...@@ -659,6 +688,9 @@ def mxfp4_w4a16_moe_quant_config( ...@@ -659,6 +688,9 @@ def mxfp4_w4a16_moe_quant_config(
_a2=FusedMoEQuantDesc(), _a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
) )
...@@ -670,6 +702,9 @@ def mxfp4_mxfp8_moe_quant_config( ...@@ -670,6 +702,9 @@ def mxfp4_mxfp8_moe_quant_config(
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for mxfp4 activations and mxfp4 weights. Construct a quant config for mxfp4 activations and mxfp4 weights.
...@@ -679,6 +714,9 @@ def mxfp4_mxfp8_moe_quant_config( ...@@ -679,6 +714,9 @@ def mxfp4_mxfp8_moe_quant_config(
_a2=FusedMoEQuantDesc("mxfp8"), _a2=FusedMoEQuantDesc("mxfp8"),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
) )
...@@ -712,6 +750,9 @@ def ocp_mx_moe_quant_config( ...@@ -712,6 +750,9 @@ def ocp_mx_moe_quant_config(
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for mxfp4 activations and mxfp4 weights. Construct a quant config for mxfp4 activations and mxfp4 weights.
...@@ -729,6 +770,9 @@ def ocp_mx_moe_quant_config( ...@@ -729,6 +770,9 @@ def ocp_mx_moe_quant_config(
per_act_token_quant=False, per_act_token_quant=False,
per_out_ch_quant=False, per_out_ch_quant=False,
block_shape=block_shape, block_shape=block_shape,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
) )
......
...@@ -25,15 +25,20 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -25,15 +25,20 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8_packed_for_deepgemm, per_token_group_quant_fp8_packed_for_deepgemm,
silu_mul_per_token_group_quant_fp8_colmajor, silu_mul_per_token_group_quant_fp8_colmajor,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
silu_mul_quant_fp8_packed_triton as fused_silu_mul_fp8_quant_packed,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8Dynamic128Sym, kFp8Dynamic128Sym,
kFp8Static128BlockSym, kFp8Static128BlockSym,
kMxfp4Static,
) )
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT, DeepGemmQuantScaleFMT,
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
is_deep_gemm_supported, is_deep_gemm_supported,
m_grouped_fp8_fp4_gemm_nt_contiguous,
m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nt_contiguous,
) )
from vllm.utils.import_utils import has_deep_gemm from vllm.utils.import_utils import has_deep_gemm
...@@ -197,8 +202,14 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular): ...@@ -197,8 +202,14 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
M_sum, N = input.size() M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation) activation_out_dim = self.adjust_N_for_activation(N, activation)
# 1. DeepGemm UE8M0: use packed per-token-group quant # 1. DeepGemm UE8M0: fused SiLU+mul+clamp+quant+pack
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0: if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
if activation == MoEActivation.SILU:
return fused_silu_mul_fp8_quant_packed(
input=input,
output_q=output,
group_size=block_k,
)
act_out = torch.empty( act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device (M_sum, activation_out_dim), dtype=input.dtype, device=input.device
) )
...@@ -312,3 +323,225 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular): ...@@ -312,3 +323,225 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
expert_map=expert_map, expert_map=expert_map,
output=output, output=output,
) )
class DeepGemmFP4Experts(mk.FusedMoEExpertsModular):
"""DeepGemm-based fused MoE expert implementation for FP4 weights.
Uses m_grouped_fp8_fp4_gemm_nt_contiguous with FP8 activations and
MXFP4 (FP4 E2M1 packed as uint8) weights. Requires SM100+ (Blackwell).
"""
# FP8 activation block size (hardcoded since mxfp4_w4a8 quant config
# does not set a block_shape on the activation descriptor).
_ACT_BLOCK_K = 128
# FP4 weight block size
_WEIGHT_BLOCK_K = 32
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
    """Initialise the base class and validate the quantization config.

    Args:
        moe_config: Passed through to the base class.
        quant_config: Must describe ``"mxfp4"`` weights with neither
            per-activation-token nor per-output-channel quantization. Its
            ``gemm1_clamp_limit`` is stored on the instance.
    """
    super().__init__(moe_config=moe_config, quant_config=quant_config)

    # Reject quantization schemes other than the one named in the docstring.
    assert quant_config.weight_quant_dtype == "mxfp4"
    assert not quant_config.per_act_token_quant
    assert not quant_config.per_out_ch_quant

    clamp_limit = quant_config.gemm1_clamp_limit
    self.gemm1_clamp_limit = clamp_limit
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format of this implementation (``Standard``)."""
    fmt = mk.FusedMoEActivationFormat.Standard
    return fmt
@staticmethod
def _supports_current_device() -> bool:
    """Tell whether the current device can run this implementation.

    Requires DeepGEMM support and a device in capability family 100.
    """
    from vllm.platforms import current_platform

    deep_gemm_ok = is_deep_gemm_supported()
    if not deep_gemm_ok:
        return deep_gemm_ok
    return current_platform.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Tell whether the (weight, activation) quantization pair is supported.

    The only accepted pair is ``(kMxfp4Static, kFp8Dynamic128Sym)``.
    """
    supported_pair = (kMxfp4Static, kFp8Dynamic128Sym)
    return supported_pair == (weight_key, activation_key)
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Tell whether ``activation`` is ``SILU`` or ``SWIGLUSTEP``."""
    accepted = (MoEActivation.SILU, MoEActivation.SWIGLUSTEP)
    return activation in accepted
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Accept any parallel config that uses neither FI NVL kernel variant.

    Returns ``False`` when either ``use_fi_nvl_two_sided_kernels`` or
    ``use_fi_nvl_one_sided_kernels`` is set.
    """
    return (
        not moe_parallel_config.use_fi_nvl_two_sided_kernels
        and not moe_parallel_config.use_fi_nvl_one_sided_kernels
    )
def supports_expert_map(self) -> bool:
    """Report that an expert map is supported."""
    supported = True
    return supported
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    """Return a fresh ``TopKWeightAndReduceNoOP`` instance."""
    reducer = TopKWeightAndReduceNoOP()
    return reducer
def workspace_shapes(
    self,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    """Compute the shapes of the two workspaces and the output.

    Args:
        M: Number of input rows.
        N: Width before the activation.
        K: Hidden width.
        topk: Passed to ``compute_aligned_M``.
        global_num_experts: Not read by this method.
        local_num_experts: Passed to ``compute_aligned_M``.
        expert_tokens_meta: Passed to ``compute_aligned_M``.
        activation: Passed to ``adjust_N_for_activation``.

    Returns:
        ``(workspace1, workspace2, output)`` where both workspaces have the
        aligned row count and ``output`` is ``(M, K)``.
    """
    row_alignment = get_mk_alignment_for_contiguous_layout()[0]
    aligned_rows = compute_aligned_M(
        M, topk, local_num_experts, row_alignment, expert_tokens_meta
    )
    assert aligned_rows % row_alignment == 0

    act_width = self.adjust_N_for_activation(N, activation)
    return (
        (aligned_rows, max(act_width, K)),
        (aligned_rows, max(N, K)),
        (M, K),
    )
def _act_mul_quant(
    self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply the gated activation and re-quantize the result to FP8.

    The path depends on the scale format reported by
    ``DeepGemmQuantScaleFMT.from_oracle()``:

    * ``UE8M0``: the fused SiLU-mul-quant-pack kernel, with
      ``self.gemm1_clamp_limit`` as clamp limit. SILU only.
    * otherwise, SILU: ``silu_mul_per_token_group_quant_fp8_colmajor``.
    * otherwise, any other activation: ``self.activation`` into a temporary,
      then ``per_token_group_quant_fp8`` with column-major scales.

    Args:
        input: 2-D tensor to activate.
        output: Destination for the quantized values.
        activation: Which activation to apply.

    Returns:
        The pair returned by the selected quantization helper.
    """
    # NOTE(review): `_supports_activation` also accepts SWIGLUSTEP, but the
    # UE8M0 branch asserts SILU - confirm SWIGLUSTEP never reaches it.
    # NOTE(review): only the UE8M0 branch forwards `gemm1_clamp_limit`;
    # confirm the other branches are meant to skip clamping.
    scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
    group_size = self._ACT_BLOCK_K
    rows, width = input.size()
    act_width = self.adjust_N_for_activation(width, activation)

    if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
        assert activation == MoEActivation.SILU
        return fused_silu_mul_fp8_quant_packed(
            input=input,
            output_q=output,
            group_size=group_size,
            clamp_limit=self.gemm1_clamp_limit,
        )

    if activation != MoEActivation.SILU:
        # Generic path: run the activation into a temporary, then quantize.
        activated = torch.empty(
            (rows, act_width), dtype=input.dtype, device=input.device
        )
        self.activation(activation, activated, input)
        return per_token_group_quant_fp8(
            activated, group_size, column_major_scales=True, out_q=output
        )

    return silu_mul_per_token_group_quant_fp8_colmajor(
        input=input,
        output=output,
        use_ue8m0=scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
    )
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: MoEActivation,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    a1q_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    apply_router_weight_on_input: bool,
) -> None:
    """Run the expert computation and write the reduced result to ``output``.

    Steps, in order: permute the quantized input by expert, run the first
    grouped FP8xFP4 GEMM, apply activation + FP8 requantization, run the
    second grouped FP8xFP4 GEMM, then unpermute and reduce into ``output``.

    Both workspaces are reused: ``workspace13`` (viewed as
    ``float8_e4m3fn``) first holds the permuted input and later the
    requantized activations; ``workspace2`` first holds the FC1 output and
    later the FC2 output. The statement order therefore matters.

    Args:
        output: Destination tensor passed to the unpermute-and-reduce step.
        hidden_states: Already-quantized input activations (used as ``a1q``).
        w1: First-GEMM weights; reinterpreted as ``int8`` for the GEMM call.
        w2: Second-GEMM weights; reinterpreted as ``int8`` for the GEMM call.
        topk_weights: Router weights; replaced by ones when
            ``apply_router_weight_on_input`` is true.
        topk_ids: Selected expert ids; its two sizes feed the aligned row
            count computation.
        activation: Activation applied between the two GEMMs.
        global_num_experts: ``-1`` is replaced by the local expert count.
        expert_map: Optional map forwarded to permute and unpermute.
        a1q_scale: Scales for ``hidden_states``; must not be ``None``.
        a2_scale: Must be ``None``.
        workspace13: Scratch buffer (see above).
        workspace2: Scratch buffer (see above).
        expert_tokens_meta: Optional metadata forwarded to the helpers.
        apply_router_weight_on_input: When true, the reduce step uses
            all-ones weights.
    """
    assert a1q_scale is not None
    assert a2_scale is None
    assert self.w1_scale is not None
    assert self.w2_scale is not None
    a1q = hidden_states
    _, N, _ = w1.size()
    # K comes from activations (full hidden dim), not from w1 which is
    # packed FP4 (E, N, K//2).
    K = a1q.size(1)
    local_num_experts = w1.size(0)
    # NOTE(review): the value assigned here is not read again in this method.
    if global_num_experts == -1:
        global_num_experts = local_num_experts
    M_sum = compute_aligned_M(
        M=topk_ids.size(0),
        num_topk=topk_ids.size(1),
        local_num_experts=local_num_experts,
        alignment=get_mk_alignment_for_contiguous_layout()[0],
        expert_tokens_meta=expert_tokens_meta,
    )
    # First use of workspace13: destination for the permuted FP8 input.
    a1q_perm = _resize_cache(
        workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
    )
    a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
        aq=a1q,
        aq_scale=a1q_scale,
        topk_ids=topk_ids,
        local_num_experts=local_num_experts,
        expert_map=expert_map,
        expert_tokens_meta=expert_tokens_meta,
        aq_out=a1q_perm,
    )
    assert a1q.size(0) == M_sum
    # FC1: FP8 activations x FP4 weights
    # DeepGEMM 2.4.2 requires FP4-packed weights as int8 (kPackedFP4).
    mm1_out = _resize_cache(workspace2, (M_sum, N))
    m_grouped_fp8_fp4_gemm_nt_contiguous(
        (a1q, a1q_scale),
        (w1.view(torch.int8), self.w1_scale),
        mm1_out,
        expert_ids,
        recipe_a=(1, self._ACT_BLOCK_K),
        recipe_b=(1, self._WEIGHT_BLOCK_K),
    )
    # SwiGLU activation + FP8 requant
    activation_out_dim = self.adjust_N_for_activation(N, activation)
    # Second use of workspace13: the permuted input is no longer needed
    # once FC1 has run, so the buffer is reused for the requantized output.
    quant_out = _resize_cache(
        workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
    )
    a2q, a2q_scale = self._act_mul_quant(
        input=mm1_out.view(-1, N), output=quant_out, activation=activation
    )
    # FC2: FP8 activations x FP4 weights
    # workspace2 is reused here; mm1_out has already been consumed above.
    mm2_out = _resize_cache(workspace2, (M_sum, K))
    m_grouped_fp8_fp4_gemm_nt_contiguous(
        (a2q, a2q_scale),
        (w2.view(torch.int8), self.w2_scale),
        mm2_out,
        expert_ids,
        recipe_a=(1, self._ACT_BLOCK_K),
        recipe_b=(1, self._WEIGHT_BLOCK_K),
    )
    # Use all-ones weights so the reduce step does not scale the outputs
    # by the router weights when the flag is set.
    if apply_router_weight_on_input:
        topk_weights = torch.ones_like(topk_weights)
    deepgemm_unpermute_and_reduce(
        a=mm2_out,
        topk_ids=topk_ids,
        topk_weights=topk_weights,
        inv_perm=inv_perm,
        expert_map=expert_map,
        output=output,
    )
...@@ -28,8 +28,162 @@ from vllm.platforms import current_platform ...@@ -28,8 +28,162 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
from ..utils import swiglu_limit_func
logger = init_logger(__name__) logger = init_logger(__name__)
def _patch_make_bitmatrix_metadata() -> None:
    """Monkey-patch make_bitmatrix_metadata to support non-power-of-2 top_k.

    triton's tl.arange requires a power-of-2 range. The original kernel
    computes BLOCK_SIZE = BLOCK_PER_TOK * TOKS_PER_ROW (= 32 * top_k). For
    DeepSeek-V4 with top_k=6 this gives 192, which is not a power of 2 and
    causes a compile error at the first forward pass.

    Fix: define a drop-in replacement kernel that accepts an extra constexpr
    BLOCK_SIZE_PADDED (next power of 2 >= BLOCK_SIZE) and uses it for the
    tl.arange call while keeping the actual BLOCK_SIZE as the stride between
    thread-blocks so that all flat indices into NonzeroIndx stay correct.
    Lanes at or beyond BLOCK_SIZE are explicitly masked out on load, so they
    receive the sentinel column 0xffff and never produce a store.

    This function is called once at module load time and patches the function
    inside the triton_kernels tensor module so that SparseMatrix.__post_init__
    picks up the fixed version transparently.

    If the vendored ``vllm.third_party.triton_kernels`` modules cannot be
    imported, the function returns without patching anything.

    Raises:
        ImportError: If ``triton_kernels.tensor`` cannot be imported after
            the vendored modules were found.
    """
    import torch
    import triton
    import triton.language as tl

    try:
        from vllm.third_party.triton_kernels.tensor_details import (
            bitmatrix as _bm,
        )
        from vllm.third_party.triton_kernels.tensor_details.bitmatrix import (
            BitmatrixMetadata,
            _keyed_add,
            cdiv,
        )
        from vllm.third_party.triton_kernels.tensor_details.bitmatrix_details.sum_bitmatrix_rows import (  # noqa: E501
            sum_bitmatrix_rows,
        )
    except ImportError:
        # Nothing to patch when the vendored kernels are unavailable.
        return

    @triton.jit
    def _stage2_pow2(
        ColSortedIndx,
        RowSortedIndx,
        NonzeroIndx,
        n_tokens,
        ColPartialSum,
        stride_pm,
        stride_pn,
        ColOffs,
        TOKS_PER_ROW: tl.constexpr,
        BLOCK_PER_TOK: tl.constexpr,
        BLOCK_SIZE_PADDED: tl.constexpr,
    ):
        # Actual number of elements per block (may not be a power of 2).
        BLOCK_SIZE: tl.constexpr = BLOCK_PER_TOK * TOKS_PER_ROW
        # The local offset is packed into the low 16 bits of the sort key.
        tl.static_assert(BLOCK_SIZE_PADDED <= 32768)
        if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
            n_tokens = tl.load(n_tokens)
        nonzero_indx_size = n_tokens * TOKS_PER_ROW
        pid_m = tl.program_id(0)
        # Use BLOCK_SIZE_PADDED (a power of 2) for tl.arange, but stride by
        # the actual BLOCK_SIZE so flat positions in NonzeroIndx are correct.
        offs_local = tl.arange(0, BLOCK_SIZE_PADDED)
        offs_global = pid_m * BLOCK_SIZE + offs_local
        # Padding lanes (offs_local >= BLOCK_SIZE) map to global offsets that
        # belong to the *next* program's slice, which are in range for every
        # program but the last. A bounds check alone would therefore let this
        # program re-process its neighbour's elements, so the padding lanes
        # are masked explicitly. Masked lanes load -1, get col_indx = 0xffff
        # after the shift below, and are dropped by the second mask.
        mask = (offs_local < BLOCK_SIZE) & (offs_global < nonzero_indx_size)
        col_indx = tl.load(NonzeroIndx + offs_global, mask=mask, other=-1).to(tl.uint32)
        # Sort key: column index in the high 16 bits, local offset in the low.
        kv_pairs = ((col_indx << 16) | offs_local).to(tl.uint32)
        kv_pairs = tl.sort(kv_pairs, 0)
        col_indx = kv_pairs >> 16
        offs_global = pid_m * BLOCK_SIZE + (kv_pairs & 0xFFFF)
        mask = col_indx != 0xFFFF
        # Keyed scan: per-column running count within this block.
        x = kv_pairs & 0xFFFF0000 | 0x00000001
        cols_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
        exclusive_run_lengths = (cols_and_inclusive_run_lengths - 1) & 0xFFFF
        row_sorted_indx = tl.load(
            ColPartialSum + pid_m * stride_pm + col_indx * stride_pn, mask=mask
        )
        row_sorted_indx += tl.load(ColOffs + col_indx, mask=mask)
        row_sorted_indx += exclusive_run_lengths
        tl.store(RowSortedIndx + offs_global, row_sorted_indx, mask=mask)
        tl.store(ColSortedIndx + row_sorted_indx, offs_global, mask=mask)

    def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
        """Replacement for make_bitmatrix_metadata that pads the stage-2 block."""
        assert nonzero_indx.ndim == 2
        PARTIAL_BLOCK_M = 32
        col_sum, col_partial_sum = sum_bitmatrix_rows(
            bitmatrix, partials_block_size=PARTIAL_BLOCK_M
        )
        device = bitmatrix.device
        n_indx = nonzero_indx.numel()
        n_cols = bitmatrix.shape[1]
        col_offs = torch.empty(n_cols, dtype=torch.int32, device=device)
        # One allocation backs both index arrays; they are views into it.
        combined_indx = torch.empty(n_indx * 2, dtype=torch.int32, device=device)
        col_sorted_indx = combined_indx[:n_indx]
        row_sorted_indx = combined_indx[n_indx:]
        MEMSET_BLOCK = 1024
        memset_grid = (cdiv(n_indx * 2, MEMSET_BLOCK) + n_cols + 1,)
        _bm._bitmatrix_metadata_compute_stage1[memset_grid](
            combined_indx,
            n_indx * 2,
            -1,
            MEMSET_BLOCK,
            col_sum,
            col_offs,
            col_sum.shape[0],
            col_partial_sum,
            col_partial_sum.shape[0],
            col_partial_sum.stride(0),
            col_partial_sum.stride(1),
            BLOCK_M=512,
            BLOCK_N=512,
        )
        toks_per_row = nonzero_indx.shape[-1]
        block_size = PARTIAL_BLOCK_M * toks_per_row
        # Next power of 2 >= block_size (required by tl.arange).
        block_size_padded = 1 << (max(block_size, 1) - 1).bit_length()
        compute_grid = (cdiv(bitmatrix.shape_max[0], PARTIAL_BLOCK_M),)
        _stage2_pow2[compute_grid](
            col_sorted_indx,
            row_sorted_indx,
            nonzero_indx,
            bitmatrix.shape[0],
            col_partial_sum,
            col_partial_sum.stride(0),
            col_partial_sum.stride(1),
            col_offs,
            TOKS_PER_ROW=toks_per_row,
            BLOCK_PER_TOK=PARTIAL_BLOCK_M,
            BLOCK_SIZE_PADDED=block_size_padded,
        )
        return BitmatrixMetadata(
            col_sum=col_sum,
            col_sorted_indx=col_sorted_indx,
            row_sorted_indx=row_sorted_indx,
        )

    # The most reliable patch point: SparseMatrix.__post_init__ looks up
    # make_bitmatrix_metadata via its own __globals__ dict (the tensor.py
    # module dict). Patching through __globals__ works regardless of how
    # sys.modules maps "triton_kernels.tensor" vs
    # "vllm.third_party.triton_kernels.tensor".
    from triton_kernels.tensor import SparseMatrix as _SparseMatrix

    _SparseMatrix.__post_init__.__globals__["make_bitmatrix_metadata"] = (
        _make_bitmatrix_metadata_pow2_safe
    )
    # Also patch the bitmatrix module itself in case it is imported directly.
    _bm.make_bitmatrix_metadata = _make_bitmatrix_metadata_pow2_safe
use_legacy_triton_kernels = False use_legacy_triton_kernels = False
if has_triton_kernels(): if has_triton_kernels():
...@@ -59,6 +213,8 @@ if has_triton_kernels(): ...@@ -59,6 +213,8 @@ if has_triton_kernels():
use_legacy_triton_kernels = True use_legacy_triton_kernels = True
else: else:
raise raise
if not use_legacy_triton_kernels:
_patch_make_bitmatrix_metadata()
except (AttributeError, ImportError) as e: except (AttributeError, ImportError) as e:
logger.error( logger.error(
"Failed to import Triton kernels. Please make sure your triton " "Failed to import Triton kernels. Please make sure your triton "
...@@ -497,6 +653,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular): ...@@ -497,6 +653,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
return False return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell) # (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5). # and ROCm gfx942/gfx950 (which map to 9.4/9.5).
if not has_triton_kernels():
return False
return (9, 0) <= (cap.major, cap.minor) < (11, 0) return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod @staticmethod
...@@ -698,6 +856,37 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts): ...@@ -698,6 +856,37 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
def moe_sum(self, input: torch.Tensor, output: torch.Tensor): def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
ops.moe_sum(input, output) ops.moe_sum(input, output)
def activation(
    self,
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> None:
    """Apply ``activation`` to ``input``, writing the result into ``output``.

    Two cases are handled here; everything else is delegated to the parent
    class:

    * ``SWIGLUOAI``: calls ``torch.ops._C.swigluoai_and_mul`` with alpha and
      clamp limit taken from the quant config, falling back to ``1.702`` and
      ``7.0`` when the config leaves them unset.
    * ``SILU`` with a configured ``gemm1_clamp_limit``: calls
      ``swiglu_limit_func`` with that limit.

    Args:
        activation: Activation to apply.
        output: Tensor that receives the result.
        input: Tensor to activate.
    """
    # Fall back to the unquantized config when none was supplied.
    cfg = self.quant_config or FUSED_MOE_UNQUANTIZED_CONFIG
    clamp_limit = cfg.gemm1_clamp_limit

    if activation == MoEActivation.SWIGLUOAI:
        alpha = 1.702 if cfg.gemm1_alpha is None else cfg.gemm1_alpha
        limit = 7.0 if clamp_limit is None else clamp_limit
        torch.ops._C.swigluoai_and_mul(output, input, alpha, limit)
        return

    if activation == MoEActivation.SILU and clamp_limit is not None:
        swiglu_limit_func(output, input, clamp_limit)
        return

    super().activation(activation, output, input)
def apply( def apply(
self, self,
output: torch.Tensor, output: torch.Tensor,
...@@ -812,9 +1001,9 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts): ...@@ -812,9 +1001,9 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
act_input, act_input,
) )
# matmul_ogs grouped reduction fuse sum across multiple experts: # matmul_ogs grouped reduction fuses sum across multiple experts:
# y[dst_indx // n_expts_act, :] += x # y[dst_indx // n_expts_act, :] += x
# Need to set n_expts_act to 1 to unfuse moe_sum # Set n_expts_act to 1 to unfuse the sum so we can do it manually via moe_sum.
routing_data.n_expts_act = 1 routing_data.n_expts_act = 1
matmul_ogs( matmul_ogs(
...@@ -878,6 +1067,8 @@ class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic): ...@@ -878,6 +1067,8 @@ class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
return False return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell) # (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5). # and ROCm gfx942/gfx950 (which map to 9.4/9.5).
if not has_triton_kernels():
return False
return (9, 0) <= (cap.major, cap.minor) < (11, 0) return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod @staticmethod
......
...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kMxfp4Static, kMxfp4Static,
...@@ -32,10 +33,8 @@ class TrtLlmMxfp4ExpertsBase: ...@@ -32,10 +33,8 @@ class TrtLlmMxfp4ExpertsBase:
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
**kwargs,
): ):
# NOTE: FusedMoEExperts.__init__ is called by the concrete subclass
# (Monolithic/Modular) via MRO, not here, to avoid mypy issues with
# multiple inheritance. This matches the NvFP4 expert pattern.
self.moe_config = moe_config self.moe_config = moe_config
self.quant_config = quant_config self.quant_config = quant_config
...@@ -48,23 +47,34 @@ class TrtLlmMxfp4ExpertsBase: ...@@ -48,23 +47,34 @@ class TrtLlmMxfp4ExpertsBase:
self.local_num_experts = moe_config.num_local_experts self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank self.ep_rank = moe_config.moe_parallel_config.ep_rank
# MXFP4-specific TRTLLM parameters # MXFP4-specific TRTLLM parameters from quant_config
device = torch.accelerator.current_device_index() device = torch.accelerator.current_device_index()
self.gemm1_alpha = torch.tensor( if quant_config.gemm1_alpha is not None:
[1.702] * self.local_num_experts, self.gemm1_alpha = torch.tensor(
dtype=torch.float32, [quant_config.gemm1_alpha] * self.local_num_experts,
device=device, dtype=torch.float32,
) device=device,
self.gemm1_beta = torch.tensor( )
[1.0] * self.local_num_experts, else:
dtype=torch.float32, self.gemm1_alpha = None
device=device,
) if quant_config.gemm1_beta is not None:
self.gemm1_clamp_limit = torch.tensor( self.gemm1_beta = torch.tensor(
[7.0] * self.local_num_experts, [quant_config.gemm1_beta] * self.local_num_experts,
dtype=torch.float32, dtype=torch.float32,
device=device, device=device,
) )
else:
self.gemm1_beta = None
if quant_config.gemm1_clamp_limit is not None:
self.gemm1_clamp_limit = torch.tensor(
[quant_config.gemm1_clamp_limit] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
else:
self.gemm1_clamp_limit = None
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
...@@ -97,7 +107,7 @@ class TrtLlmMxfp4ExpertsBase: ...@@ -97,7 +107,7 @@ class TrtLlmMxfp4ExpertsBase:
@staticmethod @staticmethod
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SWIGLUOAI return activation in (MoEActivation.SWIGLUOAI, MoEActivation.SILU)
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
...@@ -190,36 +200,41 @@ class TrtLlmMxfp4ExpertsMonolithic( ...@@ -190,36 +200,41 @@ class TrtLlmMxfp4ExpertsMonolithic(
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
return trtllm_fp4_block_scale_moe( from vllm.utils.flashinfer import _is_fi_autotuning, autotune
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None, with autotune(_is_fi_autotuning):
hidden_states=x_quant, trtllm_fp4_block_scale_moe(
hidden_states_scale=x_scale, routing_logits=router_logits.to(torch.bfloat16),
gemm1_weights=w1, routing_bias=None,
gemm1_weights_scale=self.w1_scale, hidden_states=x_quant,
gemm1_bias=self.w1_bias, hidden_states_scale=x_scale,
gemm1_alpha=self.gemm1_alpha, gemm1_weights=w1,
gemm1_beta=self.gemm1_beta, gemm1_weights_scale=self.w1_scale,
gemm1_clamp_limit=self.gemm1_clamp_limit, gemm1_bias=self.w1_bias,
gemm2_weights=w2, gemm1_alpha=self.gemm1_alpha,
gemm2_weights_scale=self.w2_scale, gemm1_beta=self.gemm1_beta,
gemm2_bias=self.w2_bias, gemm1_clamp_limit=self.gemm1_clamp_limit,
output1_scale_scalar=None, gemm2_weights=w2,
output1_scale_gate_scalar=None, gemm2_weights_scale=self.w2_scale,
output2_scale_scalar=None, gemm2_bias=self.w2_bias,
num_experts=global_num_experts, output1_scale_scalar=None,
top_k=self.topk, output1_scale_gate_scalar=None,
n_group=None, output2_scale_scalar=None,
topk_group=None, num_experts=global_num_experts,
intermediate_size=self.intermediate_size_per_partition, top_k=self.topk,
local_expert_offset=self.ep_rank * self.local_num_experts, n_group=None,
local_num_experts=self.local_num_experts, topk_group=None,
routed_scaling_factor=None, intermediate_size=self.intermediate_size_per_partition,
routing_method_type=self.routing_method_type, local_expert_offset=self.ep_rank * self.local_num_experts,
do_finalize=True, local_num_experts=self.local_num_experts,
tune_max_num_tokens=max(self.max_capture_size, 1), routed_scaling_factor=None,
output=output, routing_method_type=self.routing_method_type,
)[0] do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)
return output
class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular): class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular):
...@@ -239,6 +254,16 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula ...@@ -239,6 +254,16 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
) -> bool: ) -> bool:
return True return True
@staticmethod
def _supports_routing_method(
    routing_method: RoutingMethodType,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Accept every routing method.

    The modular kernel handles only the expert computation; routing is done
    externally, so no routing method is rejected here.

    Args:
        routing_method: Ignored.
        weight_key: Ignored.
        activation_key: Ignored.

    Returns:
        Always ``True``.
    """
    return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
...@@ -282,7 +307,7 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula ...@@ -282,7 +307,7 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
): ):
topk = topk_ids.size(-1) topk = topk_ids.size(-1)
local_num_experts = w1.size(0) local_num_experts = w1.size(0)
intermediate_size = w2.size(1) intermediate_size = self.intermediate_size_per_partition
local_expert_offset = self.moe_config.ep_rank * local_num_experts local_expert_offset = self.moe_config.ep_rank * local_num_experts
# Handle input quantization # Handle input quantization
...@@ -302,9 +327,8 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula ...@@ -302,9 +327,8 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
x_quant = hidden_states x_quant = hidden_states
x_scale = None x_scale = None
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( # Pack topk ids and weights into format expected by the kernel.
torch.bfloat16 packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
).view(torch.int16)
assert self.w1_scale is not None assert self.w1_scale is not None
assert self.w2_scale is not None assert self.w2_scale is not None
...@@ -333,7 +357,10 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula ...@@ -333,7 +357,10 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
"local_expert_offset": local_expert_offset, "local_expert_offset": local_expert_offset,
"local_num_experts": local_num_experts, "local_num_experts": local_num_experts,
"routed_scaling_factor": None, "routed_scaling_factor": None,
"routing_method_type": self.routing_method_type, # Modular kernel receives pre-routed tokens, so routing
# is already done. Use Renormalize as a safe default that
# the TRTLLM C++ kernel supports.
"routing_method_type": RoutingMethodType.Renormalize,
"do_finalize": True, "do_finalize": True,
"output": output, "output": output,
"tune_max_num_tokens": max(self.max_capture_size, 1), "tune_max_num_tokens": max(self.max_capture_size, 1),
...@@ -341,12 +368,9 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula ...@@ -341,12 +368,9 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
from flashinfer import trtllm_fp4_block_scale_routed_moe from flashinfer import trtllm_fp4_block_scale_routed_moe
from vllm.utils.flashinfer import autotune from vllm.utils.flashinfer import _is_fi_autotuning, autotune
with autotune(False): with autotune(_is_fi_autotuning):
# Enable autotune when,
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is
# resolved.
trtllm_fp4_block_scale_routed_moe(**kwargs) trtllm_fp4_block_scale_routed_moe(**kwargs)
return output return output
...@@ -50,6 +50,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -50,6 +50,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from .utils import swiglu_limit_func
def _fused_marlin_moe( def _fused_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -88,6 +90,7 @@ def _fused_marlin_moe( ...@@ -88,6 +90,7 @@ def _fused_marlin_moe(
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None, input_dtype: torch.dtype | None = None,
is_k_full: bool = True, is_k_full: bool = True,
clamp_limit: float | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert hidden_states.ndim == 2 assert hidden_states.ndim == 2
M, K = hidden_states.size() M, K = hidden_states.size()
...@@ -155,11 +158,18 @@ def _fused_marlin_moe( ...@@ -155,11 +158,18 @@ def _fused_marlin_moe(
use_fp32_reduce=True, use_fp32_reduce=True,
is_zp_float=False, is_zp_float=False,
) )
activation_func( if clamp_limit is not None and activation == MoEActivation.SILU:
activation, swiglu_limit_func(
intermediate_cache2, intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N), intermediate_cache1.view(-1, w13_num_shards * N),
) clamp_limit,
)
else:
activation_func(
activation,
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
)
if output is None: if output is None:
output = intermediate_cache3 output = intermediate_cache3
...@@ -247,6 +257,7 @@ def fused_marlin_moe( ...@@ -247,6 +257,7 @@ def fused_marlin_moe(
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None, input_dtype: torch.dtype | None = None,
inplace: bool = False, inplace: bool = False,
clamp_limit: float | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -363,6 +374,7 @@ def fused_marlin_moe( ...@@ -363,6 +374,7 @@ def fused_marlin_moe(
output=None, output=None,
input_dtype=input_dtype, input_dtype=input_dtype,
is_k_full=is_k_full, is_k_full=is_k_full,
clamp_limit=clamp_limit,
).view(-1, topk, K) ).view(-1, topk, K)
if output is None: if output is None:
...@@ -557,6 +569,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular): ...@@ -557,6 +569,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full self.is_k_full = is_k_full
self.input_dtype = get_marlin_input_dtype() self.input_dtype = get_marlin_input_dtype()
self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
super().__init__( super().__init__(
moe_config=moe_config, moe_config=moe_config,
...@@ -850,6 +863,7 @@ class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase): ...@@ -850,6 +863,7 @@ class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase):
sort_indices2=self.w2_g_idx_sort_indices, sort_indices2=self.w2_g_idx_sort_indices,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
input_dtype=self.input_dtype, input_dtype=self.input_dtype,
clamp_limit=self.gemm1_clamp_limit,
) )
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
......
...@@ -169,5 +169,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -169,5 +169,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -268,6 +268,7 @@ class FusedMoE(PluggableLayer): ...@@ -268,6 +268,7 @@ class FusedMoE(PluggableLayer):
custom_routing_function: Callable | None = None, custom_routing_function: Callable | None = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
swiglu_limit: float | None = None,
e_score_correction_bias: torch.Tensor | None = None, e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
...@@ -285,6 +286,7 @@ class FusedMoE(PluggableLayer): ...@@ -285,6 +286,7 @@ class FusedMoE(PluggableLayer):
routed_output_transform: torch.nn.Module | None = None, routed_output_transform: torch.nn.Module | None = None,
apply_routed_scale_to_output: bool = False, apply_routed_scale_to_output: bool = False,
zero_expert_type: str | None = None, zero_expert_type: str | None = None,
hash_indices_table: torch.Tensor | None = None,
): ):
super().__init__() super().__init__()
...@@ -294,6 +296,7 @@ class FusedMoE(PluggableLayer): ...@@ -294,6 +296,7 @@ class FusedMoE(PluggableLayer):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.swiglu_limit = swiglu_limit
# FIXME (varun): We should have a better way of inferring the activation # FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE # datatype. This works for now as the tensor datatype entering the MoE
...@@ -455,6 +458,7 @@ class FusedMoE(PluggableLayer): ...@@ -455,6 +458,7 @@ class FusedMoE(PluggableLayer):
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
# TODO(bnell): end attributes # TODO(bnell): end attributes
self.hash_indices_table = hash_indices_table
self.apply_router_weight_on_input = apply_router_weight_on_input self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = MoEActivation.from_str(activation) self.activation = MoEActivation.from_str(activation)
...@@ -479,6 +483,7 @@ class FusedMoE(PluggableLayer): ...@@ -479,6 +483,7 @@ class FusedMoE(PluggableLayer):
indices_type_getter=lambda: self.quant_method.topk_indices_dtype, indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
zero_expert_type=zero_expert_type, zero_expert_type=zero_expert_type,
num_logical_experts=self.logical_num_experts, num_logical_experts=self.logical_num_experts,
hash_indices_table=self.hash_indices_table,
) )
self.routing_method_type: RoutingMethodType = self.router.routing_method_type self.routing_method_type: RoutingMethodType = self.router.routing_method_type
...@@ -1541,10 +1546,12 @@ class FusedMoE(PluggableLayer): ...@@ -1541,10 +1546,12 @@ class FusedMoE(PluggableLayer):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.runner.forward( return self.runner.forward(
hidden_states, hidden_states,
router_logits, router_logits,
input_ids,
) )
@property @property
......
...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import ( ...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
FusedMoEQuantDesc,
mxfp4_mxfp8_moe_quant_config, mxfp4_mxfp8_moe_quant_config,
mxfp4_w4a16_moe_quant_config, mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config, ocp_mx_moe_quant_config,
...@@ -24,6 +25,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -24,6 +25,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8Dynamic128Sym,
kMxfp4Static, kMxfp4Static,
kMxfp8Dynamic, kMxfp8Dynamic,
) )
...@@ -46,6 +48,8 @@ if has_triton_kernels(): ...@@ -46,6 +48,8 @@ if has_triton_kernels():
class Mxfp4MoeBackend(Enum): class Mxfp4MoeBackend(Enum):
NONE = "None" NONE = "None"
# DeepGEMM FP8xFP4 backend (SM100+)
DEEPGEMM_MXFP4 = "DEEPGEMM_MXFP4"
# FlashInfer TRTLLM backends # FlashInfer TRTLLM backends
FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8" FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8"
FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16" FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16"
...@@ -81,7 +85,14 @@ TRITON_BACKENDS = ( ...@@ -81,7 +85,14 @@ TRITON_BACKENDS = (
def backend_to_kernel_cls( def backend_to_kernel_cls(
backend: Mxfp4MoeBackend, backend: Mxfp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]: ) -> list[type[mk.FusedMoEExperts]]:
if backend in ( if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import (
DeepGemmFP4Experts,
)
return [DeepGemmFP4Experts]
elif backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
): ):
...@@ -159,11 +170,13 @@ def backend_to_kernel_cls( ...@@ -159,11 +170,13 @@ def backend_to_kernel_cls(
def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend: def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
"""Map user's moe_backend string to Mxfp4MoeBackend.""" """Map user's moe_backend string to Mxfp4MoeBackend."""
mapping: dict[str, Mxfp4MoeBackend] = { mapping: dict[str, Mxfp4MoeBackend] = {
"deep_gemm": Mxfp4MoeBackend.DEEPGEMM_MXFP4,
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
"flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, "flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
"flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, "flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
"flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
"triton": Mxfp4MoeBackend.TRITON, "triton": Mxfp4MoeBackend.TRITON,
"triton_unfused": Mxfp4MoeBackend.TRITON_UNFUSED,
"marlin": Mxfp4MoeBackend.MARLIN, "marlin": Mxfp4MoeBackend.MARLIN,
"aiter": Mxfp4MoeBackend.AITER, "aiter": Mxfp4MoeBackend.AITER,
"xpu": Mxfp4MoeBackend.XPU, "xpu": Mxfp4MoeBackend.XPU,
...@@ -177,7 +190,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend: ...@@ -177,7 +190,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
) )
def _get_priority_backends() -> list[Mxfp4MoeBackend]: def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
""" """
Get available backends in priority order based on platform and config. Get available backends in priority order based on platform and config.
Only includes BF16 backends. MXFP8 backends are selected via env vars. Only includes BF16 backends. MXFP8 backends are selected via env vars.
...@@ -187,7 +200,9 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: ...@@ -187,7 +200,9 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend.AITER, Mxfp4MoeBackend.AITER,
Mxfp4MoeBackend.TRITON, Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.TRITON_UNFUSED, # TRITON_UNFUSED has bug with MTP support
# TODO re-enable after kernel is fixed
# TRITON_UNFUSED
Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN,
Mxfp4MoeBackend.XPU, Mxfp4MoeBackend.XPU,
...@@ -196,8 +211,28 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: ...@@ -196,8 +211,28 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
return _AVAILABLE_BACKENDS return _AVAILABLE_BACKENDS
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
    """
    Return the candidate MXFP4 MoE backends, highest priority first.

    The order is: FlashInfer TRTLLM (MXFP8 activations), DeepGEMM FP4,
    Marlin, Batched Marlin. This function does no hardware filtering itself;
    callers are expected to probe each candidate with the kernel class's
    ``is_supported_config`` and take the first one that accepts the
    deployment (e.g. a device that cannot run the first two entries falls
    through to Marlin).
    """
    return [
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.DEEPGEMM_MXFP4,
        # TODO: re-enable Mxfp4MoeBackend.TRITON_UNFUSED here once the kernel
        # bug with MTP support is fixed.
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
    ]
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
"""Map backend to its activation key (MXFP8 or None for BF16).""" """Map backend to its activation key (FP8, MXFP8, or None for BF16)."""
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
return kFp8Dynamic128Sym
if backend in ( if backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
...@@ -290,7 +325,7 @@ def select_gpt_oss_mxfp4_moe_backend( ...@@ -290,7 +325,7 @@ def select_gpt_oss_mxfp4_moe_backend(
) )
# Select kernels in order of backend. # Select kernels in order of backend.
AVAILABLE_BACKENDS = _get_priority_backends() AVAILABLE_BACKENDS = _get_priority_backends_for_gpt_oss()
# Handle explicit FlashInfer MXFP4 BF16 configuration. # Handle explicit FlashInfer MXFP4 BF16 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
...@@ -387,11 +422,95 @@ def select_gpt_oss_mxfp4_moe_backend( ...@@ -387,11 +422,95 @@ def select_gpt_oss_mxfp4_moe_backend(
return Mxfp4MoeBackend.NONE, None return Mxfp4MoeBackend.NONE, None
def select_mxfp4_moe_backend(
    config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
    """
    Pick the MXFP4 MoE backend and experts kernel class for ``config``.

    An explicit ``config.moe_backend`` (anything other than ``"auto"``) is
    honored first and is a hard requirement. Otherwise the backends from
    ``_get_priority_backends()`` are probed in order and the first kernel
    class whose ``is_supported_config`` accepts the deployment wins.

    Args:
        config: The fused-MoE configuration; ``moe_backend`` and
            ``moe_parallel_config.use_batched_activation_format`` are read.

    Returns:
        ``(backend, experts_cls)`` for the selected backend.

    Raises:
        ValueError: The explicitly requested backend does not support the
            deployment configuration.
        NotImplementedError: ``moe_backend`` is ``"auto"`` and no backend in
            the priority list supports the deployment configuration.
    """
    if config.moe_parallel_config.use_batched_activation_format:
        activation_format = mk.FusedMoEActivationFormat.BatchedExperts
    else:
        activation_format = mk.FusedMoEActivationFormat.Standard

    def _make_log_backend(backend: Mxfp4MoeBackend):
        return f"Using '{backend.value}' Mxfp4 MoE backend."

    def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
        prefix = f"Mxfp4 MoE backend '{backend.value}' does not support the "
        if not reason:
            return prefix + "deployment configuration."
        return prefix + f"deployment configuration since {reason}."

    def _return_or_raise(
        backend: Mxfp4MoeBackend,
        config: FusedMoEConfig,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
        activation_format: mk.FusedMoEActivationFormat,
    ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
        # Only the reason reported by the last rejected kernel class ends up
        # in the error message.
        last_reason: str | None = None
        for experts_cls in backend_to_kernel_cls(backend):
            ok, last_reason = experts_cls.is_supported_config(
                experts_cls, config, weight_key, activation_key, activation_format
            )
            if ok:
                logger.info_once(_make_log_backend(backend), scope="local")
                return backend, experts_cls
        raise ValueError(_make_log_unsupported(backend, last_reason))

    # An explicit moe_backend (e.g. "marlin", "triton_unfused") takes
    # precedence over the automatic priority list.
    requested = config.moe_backend
    if requested != "auto":
        explicit_backend = map_mxfp4_backend(requested)
        wants_batched = (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
        )
        # "marlin" resolves to the batched variant when the deployment uses
        # the batched activation format.
        if wants_batched and explicit_backend == Mxfp4MoeBackend.MARLIN:
            explicit_backend = Mxfp4MoeBackend.BATCHED_MARLIN
        return _return_or_raise(
            explicit_backend,
            config,
            kMxfp4Static,
            _backend_activation_key(explicit_backend),
            activation_format,
        )

    # Auto mode: walk the priority list and take the first supported kernel.
    for candidate in _get_priority_backends():
        candidate_act_key = _backend_activation_key(candidate)
        for experts_cls in backend_to_kernel_cls(candidate):
            ok, why = experts_cls.is_supported_config(
                experts_cls, config, kMxfp4Static, candidate_act_key, activation_format
            )
            if ok:
                logger.info_once(_make_log_backend(candidate), scope="local")
                return candidate, experts_cls
            logger.debug_once(_make_log_unsupported(candidate, why), scope="local")

    raise NotImplementedError(
        "No MXFP4 MoE backend supports the deployment configuration."
    )
def mxfp4_round_up_hidden_size_and_intermediate_size( def mxfp4_round_up_hidden_size_and_intermediate_size(
backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
) -> tuple[int, int]: ) -> tuple[int, int]:
"""Round up hidden_size and intermediate_size based on backend requirements.""" """Round up hidden_size and intermediate_size based on backend requirements."""
if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
# DeepGEMM requires M/N/K alignment
intermediate_size = round_up(intermediate_size, 128)
hidden_size = round_up(hidden_size, 128)
elif backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
intermediate_size = round_up(intermediate_size, 128) intermediate_size = round_up(intermediate_size, 128)
if current_platform.is_xpu(): if current_platform.is_xpu():
hidden_size = round_up(hidden_size, 128) hidden_size = round_up(hidden_size, 128)
...@@ -434,6 +553,20 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( ...@@ -434,6 +553,20 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
]: ]:
"""Convert loaded weights into backend-specific kernel format.""" """Convert loaded weights into backend-specific kernel format."""
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_upcast_e8m0_to_fp32,
)
return (
w13_weight.data,
w2_weight.data,
_upcast_e8m0_to_fp32(w13_weight_scale.data),
_upcast_e8m0_to_fp32(w2_weight_scale.data),
w13_bias,
w2_bias,
)
num_experts = w13_weight.shape[0] num_experts = w13_weight.shape[0]
intermediate_size = w13_weight.shape[1] // 2 intermediate_size = w13_weight.shape[1] // 2
hidden_size = w13_weight.shape[2] * 2 hidden_size = w13_weight.shape[2] * 2
...@@ -738,9 +871,10 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( ...@@ -738,9 +871,10 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
elif mxfp4_backend in TRITON_BACKENDS: elif mxfp4_backend in TRITON_BACKENDS:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
assert w13_bias is not None and w2_bias is not None if w13_bias is not None:
w13_bias = w13_bias.to(torch.float32) w13_bias = w13_bias.to(torch.float32)
w2_bias = w2_bias.to(torch.float32) if w2_bias is not None:
w2_bias = w2_bias.to(torch.float32)
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
w13_weight, w13_weight,
...@@ -797,15 +931,271 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( ...@@ -797,15 +931,271 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
) )
def convert_weight_to_mxfp4_moe_kernel_format(
    mxfp4_backend: Mxfp4MoeBackend,
    layer: torch.nn.Module,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    _cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Union[torch.Tensor, "PrecisionConfig"],
    Union[torch.Tensor, "PrecisionConfig"],
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Convert loaded weights into backend-specific kernel format.

    Supports DeepGEMM, Marlin, TRTLLM and Triton backends.

    Args:
        mxfp4_backend: Backend whose kernel layout the weights are converted to.
        layer: Owning module. It is forwarded to the Marlin helper, and on the
            Triton path its ``w13_weight`` / ``w2_weight`` attributes are
            deleted.
        w13_weight: Fused gate/up expert weights. Indexed here as
            ``(num_experts, 2 * intermediate_size, hidden_size // 2)``.
        w2_weight: Down-projection expert weights.
        w13_weight_scale: Block scales for ``w13_weight``.
        w2_weight_scale: Block scales for ``w2_weight``.
        w13_bias: Optional bias for the fused gate/up projection.
        w2_bias: Optional bias for the down projection.
        _cache_permute_indices: Cache of permutation indices keyed by
            ``torch.Size``; must not be ``None`` for TRTLLM backends.

    Returns:
        ``(w13_weight, w2_weight, w13_scale, w2_scale, w13_bias, w2_bias)``.
        On the Triton path the two scale entries are ``PrecisionConfig``
        objects instead of tensors.

    Raises:
        ValueError: ``mxfp4_backend`` is not one of the supported backends.
        AssertionError: A TRTLLM backend is selected and
            ``_cache_permute_indices`` is ``None``.
    """
    if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
            _upcast_e8m0_to_fp32,
        )

        # Weights stay as uint8 packed FP4 — no layout change needed.
        # Convert E8M0 uint8 scales to float32.
        return (
            w13_weight.data,
            w2_weight.data,
            _upcast_e8m0_to_fp32(w13_weight_scale.data),
            _upcast_e8m0_to_fp32(w2_weight_scale.data),
            w13_bias,
            w2_bias,
        )

    if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
            prepare_moe_mxfp4_layer_for_marlin,
        )

        # Marlin repacking is fully delegated to the shared helper.
        return prepare_moe_mxfp4_layer_for_marlin(
            layer,
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )

    # Logical dimensions recovered from the packed w13 tensor: dim 1 holds
    # both halves of the fused projection, dim 2 is half the hidden size.
    num_experts = w13_weight.shape[0]
    intermediate_size = w13_weight.shape[1] // 2
    hidden_size = w13_weight.shape[2] * 2
    sf_block_size = 32  # mxfp4 block size

    if mxfp4_backend in TRTLLM_BACKENDS:
        assert _cache_permute_indices is not None
        from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache

        w13_weight = w13_weight.data
        w2_weight = w2_weight.data
        w13_weight_scale = w13_weight_scale.data
        w2_weight_scale = w2_weight_scale.data
        if w13_bias is not None:
            w13_bias = w13_bias.data.to(torch.float32)
        if w2_bias is not None:
            w2_bias = w2_bias.data.to(torch.float32)

        # Swap w1/w3 and interleave to match TRTLLM SwiGLU convention.
        # Standard loading gives contiguous [w1/gate, w3/up].
        # TRTLLM kernel expects interleaved [w3_0, w1_0, w3_1, w1_1, ...].
        w1_weight = w13_weight[:, :intermediate_size, :]
        w3_weight = w13_weight[:, intermediate_size:, :]
        w13_weight = torch.stack([w3_weight, w1_weight], dim=2).reshape(
            w13_weight.shape
        )

        # The scales get the same row interleave as the weights they describe.
        w1_scale = w13_weight_scale[:, :intermediate_size, :]
        w3_scale = w13_weight_scale[:, intermediate_size:, :]
        w13_weight_scale = torch.stack([w3_scale, w1_scale], dim=2).reshape(
            w13_weight_scale.shape
        )

        if w13_bias is not None:
            b1 = w13_bias[:, :intermediate_size]
            b3 = w13_bias[:, intermediate_size:]
            w13_bias = torch.stack([b3, b1], dim=2).reshape(w13_bias.shape)

        # Shuffle weights and scaling factors for transposed mma output.
        # Permute indices depend only on shape (cached by torch.Size),
        # so compute once and apply to all experts via batched indexing.
        epilogue_tile_m = 128

        # w13 weight permute
        w13_perm = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            w13_weight[0].view(torch.uint8),
            epilogue_tile_m,
        ).to(w13_weight.device)
        w13_weight = w13_weight.view(torch.uint8)[:, w13_perm].contiguous()

        # w13 scale permute + interleave
        w13_sf_perm = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            w13_weight_scale[0].view(torch.uint8),
            epilogue_tile_m,
            num_elts_per_sf=16,
        ).to(w13_weight_scale.device)
        w13_s = w13_weight_scale.view(torch.uint8)[:, w13_sf_perm].contiguous()
        E, N_s, K_s = w13_s.shape
        # The interleave helper is fed a 2-D view (experts folded into rows);
        # the result is reshaped back to per-expert layout.
        w13_weight_scale = (
            nvfp4_block_scale_interleave(w13_s.reshape(E * N_s, K_s))
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
            .view(torch.float8_e4m3fn)
        )

        # w2 weight permute
        w2_perm = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            w2_weight[0].view(torch.uint8),
            epilogue_tile_m,
        ).to(w2_weight.device)
        w2_weight = w2_weight.view(torch.uint8)[:, w2_perm].contiguous()

        # w2 scale permute + interleave
        w2_sf_perm = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            w2_weight_scale[0].view(torch.uint8),
            epilogue_tile_m,
            num_elts_per_sf=16,
        ).to(w2_weight_scale.device)
        w2_s = w2_weight_scale.view(torch.uint8)[:, w2_sf_perm].contiguous()
        E2, N2_s, K2_s = w2_s.shape
        w2_weight_scale = (
            nvfp4_block_scale_interleave(w2_s.reshape(E2 * N2_s, K2_s))
            .reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
            .view(torch.float8_e4m3fn)
        )

        # w13 bias permute
        if w13_bias is not None:
            w13_b_perm = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_bias[0].reshape(-1, 1),
                epilogue_tile_m,
            ).to(w13_bias.device)
            w13_bias = w13_bias.reshape(num_experts, -1, 1)[:, w13_b_perm].reshape(
                num_experts, -1
            )

        # w2 bias permute
        if w2_bias is not None:
            w2_b_perm = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_bias[0].reshape(-1, 1),
                epilogue_tile_m,
            ).to(w2_bias.device)
            w2_bias = w2_bias.reshape(num_experts, -1, 1)[:, w2_b_perm].reshape(
                num_experts, -1
            )

        return (
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )

    elif mxfp4_backend in TRITON_BACKENDS:
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        if mxfp4_backend == Mxfp4MoeBackend.TRITON:

            def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
                # Interleave the two halves of the last axis:
                # [a_0..a_k, b_0..b_k] -> [a_0, b_0, a_1, b_1, ...].
                shape = w.shape
                n = shape[-1]
                first = w[..., : n // 2]
                second = w[..., n // 2 :]
                stacked = torch.stack((first, second), dim=-1)
                return stacked.reshape(shape)

            # NOTE(review): shuffle_weight works on the LAST axis. For the 3-D
            # weight/scale tensors that is the packed hidden axis, while for
            # the bias it looks like the fused gate/up axis — confirm this
            # asymmetry is intended.
            w13_weight = shuffle_weight(w13_weight)
            w13_weight_scale = shuffle_weight(w13_weight_scale)

            if w13_bias is not None:
                w13_bias = shuffle_weight(w13_bias.to(torch.float32))
        else:
            if w13_bias is not None:
                w13_bias = w13_bias.to(torch.float32)

        if w2_bias is not None:
            w2_bias = w2_bias.to(torch.float32)

        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            w13_weight,
            w13_weight_scale,
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            w2_weight,
            w2_weight_scale,
        )

        w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
        )
        w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
        )

        # Drop the layer's original weight attributes; the swizzled tensors
        # returned below replace them (presumably to release the unswizzled
        # copies — confirm against the caller).
        del layer.w13_weight
        del layer.w2_weight

        return (
            w13_weight,
            w2_weight,
            w13_precision_config,
            w2_precision_config,
            w13_bias,
            w2_bias,
        )
    else:
        # DeepGEMM and Marlin are handled by the early returns above, so the
        # message lists every family this function accepts.
        raise ValueError(
            f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. "
            "Expected DeepGEMM, Marlin, TRTLLM or Triton backend."
        )
def make_mxfp4_moe_quant_config( def make_mxfp4_moe_quant_config(
mxfp4_backend: Mxfp4MoeBackend, mxfp4_backend: Mxfp4MoeBackend,
w1_scale: Union[torch.Tensor, "PrecisionConfig"], w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"],
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
swiglu_limit: float | None = None,
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend.""" """Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if mxfp4_backend in ( if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
# DeepGEMM FP4 uses FP8 per-token-group activation quantization
# with block 128, matching the FP8 DeepGEMM path.
_fp8_dtype = current_platform.fp8_dtype()
_block_shape = GroupShape(128, 128)
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
): ):
...@@ -814,6 +1204,9 @@ def make_mxfp4_moe_quant_config( ...@@ -814,6 +1204,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias, w2_bias=w2_bias,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
) )
elif mxfp4_backend in ( elif mxfp4_backend in (
Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.MARLIN,
...@@ -829,6 +1222,9 @@ def make_mxfp4_moe_quant_config( ...@@ -829,6 +1222,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias, w2_bias=w2_bias,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
) )
else: else:
return ocp_mx_moe_quant_config( return ocp_mx_moe_quant_config(
...@@ -837,6 +1233,9 @@ def make_mxfp4_moe_quant_config( ...@@ -837,6 +1233,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias, w2_bias=w2_bias,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
) )
......
...@@ -228,6 +228,8 @@ class BaseRouter(FusedMoERouter): ...@@ -228,6 +228,8 @@ class BaseRouter(FusedMoERouter):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
indices_type: torch.dtype | None, indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Compute the actual routing logic. Compute the actual routing logic.
...@@ -249,6 +251,8 @@ class BaseRouter(FusedMoERouter): ...@@ -249,6 +251,8 @@ class BaseRouter(FusedMoERouter):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Route the input hidden states to the top-k experts based on the Route the input hidden states to the top-k experts based on the
...@@ -278,7 +282,7 @@ class BaseRouter(FusedMoERouter): ...@@ -278,7 +282,7 @@ class BaseRouter(FusedMoERouter):
# Step 3: Compute routing (delegated to subclass) # Step 3: Compute routing (delegated to subclass)
topk_weights, topk_ids = self._compute_routing( topk_weights, topk_ids = self._compute_routing(
hidden_states, router_logits, indices_type hidden_states, router_logits, indices_type, input_ids=input_ids
) )
# Capture logical ids before EPLB mapping. # Capture logical ids before EPLB mapping.
......
...@@ -46,6 +46,8 @@ class CustomRoutingRouter(BaseRouter): ...@@ -46,6 +46,8 @@ class CustomRoutingRouter(BaseRouter):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
indices_type: torch.dtype | None, indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using the custom routing function.""" """Compute routing using the custom routing function."""
topk_weights, topk_ids = self.custom_routing_function( topk_weights, topk_ids = self.custom_routing_function(
......
...@@ -31,6 +31,8 @@ class FusedMoERouter(ABC): ...@@ -31,6 +31,8 @@ class FusedMoERouter(ABC):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Route the input hidden states to the top-k experts based on the Route the input hidden states to the top-k experts based on the
......
...@@ -4,6 +4,7 @@ import functools ...@@ -4,6 +4,7 @@ import functools
from collections.abc import Callable from collections.abc import Callable
import torch import torch
import torch.nn.functional as F
import vllm._custom_ops as ops import vllm._custom_ops as ops
import vllm.envs as envs import vllm.envs as envs
...@@ -56,6 +57,32 @@ def vllm_topk_sigmoid( ...@@ -56,6 +57,32 @@ def vllm_topk_sigmoid(
return topk_weights, topk_indices return topk_weights, topk_indices
def vllm_topk_softplus_sqrt(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
    input_tokens: torch.Tensor | None = None,
    hash_indices_table: torch.Tensor | None = None,
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, ...]:
    """Run the ``topk_hash_softplus_sqrt`` custom op and return its outputs.

    The same ``topk_weights`` and ``topk_indices`` tensors passed in are
    returned after the op call, so the op presumably fills them in place
    (confirm against the kernel). ``input_tokens`` and ``hash_indices_table``
    are forwarded unchanged and may be ``None``.

    Returns:
        ``(topk_weights, topk_indices)``.
    """
    # The op's positional order differs from this wrapper's signature:
    # ``routed_scaling_factor`` comes right after ``renormalize``.
    op_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        routed_scaling_factor,
        e_score_correction_bias,
        input_tokens,
        hash_indices_table,
    )
    ops.topk_hash_softplus_sqrt(*op_args)
    return (topk_weights, topk_indices)
@functools.lru_cache(maxsize=8) @functools.lru_cache(maxsize=8)
def _aiter_get_num_expert_group(num_experts: int) -> int: def _aiter_get_num_expert_group(num_experts: int) -> int:
_AITER_MAX_EXPERTS_PER_GROUP = 32 _AITER_MAX_EXPERTS_PER_GROUP = 32
...@@ -72,11 +99,14 @@ def _aiter_get_num_expert_group(num_experts: int) -> int: ...@@ -72,11 +99,14 @@ def _aiter_get_num_expert_group(num_experts: int) -> int:
def fused_topk_bias( def fused_topk_bias(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
scoring_func: str,
e_score_correction_bias: torch.Tensor, e_score_correction_bias: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
scoring_func: str = "softmax",
indices_type: torch.dtype | None = None, indices_type: torch.dtype | None = None,
input_tokens: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
): ):
if not rocm_aiter_ops.is_fused_moe_enabled(): if not rocm_aiter_ops.is_fused_moe_enabled():
assert hidden_states.size(0) == gating_output.size(0), ( assert hidden_states.size(0) == gating_output.size(0), (
...@@ -107,6 +137,8 @@ def fused_topk_bias( ...@@ -107,6 +137,8 @@ def fused_topk_bias(
renormalize, renormalize,
e_score_correction_bias, e_score_correction_bias,
) )
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids return topk_weights, topk_ids
elif scoring_func == "sigmoid": elif scoring_func == "sigmoid":
topk_weights, topk_ids = vllm_topk_sigmoid( topk_weights, topk_ids = vllm_topk_sigmoid(
...@@ -117,9 +149,24 @@ def fused_topk_bias( ...@@ -117,9 +149,24 @@ def fused_topk_bias(
renormalize, renormalize,
e_score_correction_bias, e_score_correction_bias,
) )
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids return topk_weights, topk_ids
elif scoring_func == "sqrtsoftplus":
return vllm_topk_softplus_sqrt(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
input_tokens,
hash_indices_table,
routed_scaling_factor,
)
else: else:
raise ValueError(f"Unsupported scoring function: {scoring_func}") raise ValueError(f"Unsupported scoring function: {scoring_func}")
elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid": elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid":
M = hidden_states.size(0) M = hidden_states.size(0)
num_experts = gating_output.shape[-1] num_experts = gating_output.shape[-1]
...@@ -143,6 +190,8 @@ def fused_topk_bias( ...@@ -143,6 +190,8 @@ def fused_topk_bias(
topk_group=num_expert_group, topk_group=num_expert_group,
need_renorm=renormalize, need_renorm=renormalize,
) )
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids return topk_weights, topk_ids
n_routed_experts = gating_output.shape[-1] n_routed_experts = gating_output.shape[-1]
...@@ -150,20 +199,31 @@ def fused_topk_bias( ...@@ -150,20 +199,31 @@ def fused_topk_bias(
scores = gating_output.softmax(dim=-1) scores = gating_output.softmax(dim=-1)
elif scoring_func == "sigmoid": elif scoring_func == "sigmoid":
scores = gating_output.sigmoid() scores = gating_output.sigmoid()
elif scoring_func == "sqrtsoftplus":
scores = F.softplus(gating_output).sqrt()
else: else:
raise ValueError(f"Unsupported scoring function: {scoring_func}") raise ValueError(f"Unsupported scoring function: {scoring_func}")
if e_score_correction_bias is not None:
scores_for_choice = scores.view( scores_for_choice = scores.view(
-1, n_routed_experts -1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0) ) + e_score_correction_bias.unsqueeze(0)
else:
scores_for_choice = scores.view(-1, n_routed_experts)
# For batch invariance, use sorted=True to ensure deterministic expert selection # For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = envs.VLLM_BATCH_INVARIANT if hash_indices_table is not None:
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1] topk_indices = hash_indices_table[input_tokens]
else:
use_sorted = envs.VLLM_BATCH_INVARIANT
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[
1
]
topk_weights = scores.gather(1, topk_indices) topk_weights = scores.gather(1, topk_indices)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to( topk_weights = topk_weights.to(torch.float32)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_indices.to(
torch.int32 if indices_type is None else indices_type torch.int32 if indices_type is None else indices_type
) )
...@@ -176,12 +236,14 @@ class FusedTopKBiasRouter(BaseRouter): ...@@ -176,12 +236,14 @@ class FusedTopKBiasRouter(BaseRouter):
top_k: int, top_k: int,
global_num_experts: int, global_num_experts: int,
eplb_state: EplbLayerState, eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor, e_score_correction_bias: torch.Tensor | None = None,
scoring_func: str,
renormalize: bool = True, renormalize: bool = True,
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
enable_eplb: bool = False, enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None,
*,
scoring_func: str = "sigmoid",
hash_indices_table: torch.Tensor | None = None,
): ):
super().__init__( super().__init__(
top_k=top_k, top_k=top_k,
...@@ -194,6 +256,8 @@ class FusedTopKBiasRouter(BaseRouter): ...@@ -194,6 +256,8 @@ class FusedTopKBiasRouter(BaseRouter):
self.renormalize = renormalize self.renormalize = renormalize
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor self.routed_scaling_factor = routed_scaling_factor
self.scoring_func = scoring_func
self._hash_indices_table = hash_indices_table
@property @property
def routing_method_type(self) -> RoutingMethodType: def routing_method_type(self) -> RoutingMethodType:
...@@ -210,19 +274,23 @@ class FusedTopKBiasRouter(BaseRouter): ...@@ -210,19 +274,23 @@ class FusedTopKBiasRouter(BaseRouter):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
indices_type: torch.dtype | None, indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using fused top-k with bias.""" """Compute routing using fused top-k with bias."""
topk_weights, topk_ids = fused_topk_bias( topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias.data
if self.e_score_correction_bias is not None
else None,
topk=self.top_k, topk=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
scoring_func=self.scoring_func,
indices_type=indices_type, indices_type=indices_type,
input_tokens=input_ids,
hash_indices_table=self._hash_indices_table,
routed_scaling_factor=self.routed_scaling_factor,
) )
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
return topk_weights, topk_ids return topk_weights, topk_ids
...@@ -151,6 +151,8 @@ class FusedTopKRouter(BaseRouter): ...@@ -151,6 +151,8 @@ class FusedTopKRouter(BaseRouter):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
indices_type: torch.dtype | None, indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using standard fused top-k.""" """Compute routing using standard fused top-k."""
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
......
...@@ -292,6 +292,8 @@ class GroupedTopKRouter(BaseRouter): ...@@ -292,6 +292,8 @@ class GroupedTopKRouter(BaseRouter):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
indices_type: torch.dtype | None, indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using grouped top-k.""" """Compute routing using grouped top-k."""
...@@ -308,6 +310,7 @@ class GroupedTopKRouter(BaseRouter): ...@@ -308,6 +310,7 @@ class GroupedTopKRouter(BaseRouter):
topk_weights, topk_ids = fused_topk_bias( topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias.data, e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k, topk=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
......
...@@ -55,6 +55,7 @@ def create_fused_moe_router( ...@@ -55,6 +55,7 @@ def create_fused_moe_router(
# zero expert parameters # zero expert parameters
zero_expert_type: str | None = None, zero_expert_type: str | None = None,
num_logical_experts: int | None = None, num_logical_experts: int | None = None,
hash_indices_table: torch.Tensor | None = None,
) -> FusedMoERouter: ) -> FusedMoERouter:
""" """
Factory function to create the appropriate FusedMoERouter subclass based on Factory function to create the appropriate FusedMoERouter subclass based on
...@@ -99,6 +100,9 @@ def create_fused_moe_router( ...@@ -99,6 +100,9 @@ def create_fused_moe_router(
num_logical_experts: Number of real (non-zero) experts. Required when num_logical_experts: Number of real (non-zero) experts. Required when
zero_expert_type is not None. zero_expert_type is not None.
Hash Indices Table:
Used to map input_ids to experts; needed for DeepSeek V4.
Returns: Returns:
An instance of the appropriate FusedMoERouter subclass An instance of the appropriate FusedMoERouter subclass
""" """
...@@ -179,17 +183,20 @@ def create_fused_moe_router( ...@@ -179,17 +183,20 @@ def create_fused_moe_router(
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
) )
if e_score_correction_bias is not None: assert scoring_func in ["sigmoid", "softmax", "sqrtsoftplus"]
if e_score_correction_bias is not None or hash_indices_table is not None:
return FusedTopKBiasRouter( return FusedTopKBiasRouter(
top_k=top_k, top_k=top_k,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
eplb_state=eplb_state, eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
scoring_func=scoring_func,
renormalize=renormalize, renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
scoring_func=scoring_func,
hash_indices_table=hash_indices_table,
) )
return FusedTopKRouter( return FusedTopKRouter(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment