Unverified Commit 4d51588e authored by Yifan Qiao's avatar Yifan Qiao Committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: default avatarYifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Signed-off-by: default avatarqizixi <zixi@inferact.ai>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarYongye Zhu <yongye@inferact.ai>
Co-authored-by: default avatarSimon Mo <simon@inferact.ai>
Co-authored-by: default avatarBugen Zhao <i@bugenzhao.com>
Co-authored-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarRoy Wang <yasong.wang@inferact.ai>
Co-authored-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarZhewen Li <jerven.vllm@gmail.com>
Co-authored-by: default avatarZijing Liu <liuzijing2014@gmail.com>
Co-authored-by: default avatarkhluu <khluu000@gmail.com>
Co-authored-by: default avatarqizixi <zixi@inferact.ai>
Co-authored-by: default avatarZhewen Li <zhewenli@inferact.ai>
parent 32e45636
......@@ -299,6 +299,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role."""
task: str | None
"""Model-specific task marker. Currently passed through for DeepSeek V4."""
ChatCompletionMessageParam: TypeAlias = (
OpenAIChatCompletionMessageParam
......@@ -333,6 +336,9 @@ class ConversationMessage(TypedDict, total=False):
tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role."""
task: str | None
"""Model-specific task marker. Currently passed through for DeepSeek V4."""
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
......@@ -1566,6 +1572,9 @@ def _parse_chat_message_content(
if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"]
if "task" in message and isinstance(message["task"], str):
result_msg["task"] = message["task"]
if role == "developer":
result_msg["tools"] = message.get("tools", None)
return result
......
......@@ -100,6 +100,8 @@ class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
else params.weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=self.use_deep_gemm_e8m0,
is_bmm=getattr(layer, "is_bmm", False),
bmm_batch_size=getattr(layer, "bmm_batch_size", 0),
)
replace_parameter(layer, params.WEIGHT, dg_weight)
replace_parameter(layer, scale_attr, dg_weight_scale)
......
......@@ -1422,6 +1422,20 @@ class MLADims:
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
# Check if this is a DeepseekV4 config (uses unified head_dim + rope_head_dim)
if hasattr(hf_text_config, "compress_ratios"):
# DeepseekV4 style config: unified head_dim with rope_head_dim
head_dim = hf_text_config.head_dim
rope_head_dim = hf_text_config.qk_rope_head_dim
return MLADims(
q_lora_rank=hf_text_config.q_lora_rank,
kv_lora_rank=head_dim,
qk_nope_head_dim=head_dim - rope_head_dim,
qk_rope_head_dim=rope_head_dim,
v_head_dim=head_dim,
)
# DeepseekV2/V3 style config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
......@@ -2191,6 +2205,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
# DSV3.2 MLA Specific Arguments
indexer: object | None = None,
q_pad_num_heads: int | None = None,
) -> None:
......@@ -2213,6 +2228,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.supports_quant_query_input = True
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
# Use flashinfer's optimized concat_mla_k kernel when available.
# The kernel is optimized for DeepSeek V3 dimensions:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, cast
import torch
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
)
from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
_fused_kv_compress_norm_rope_insert_indexer_attn,
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
_fused_kv_compress_norm_rope_insert_sparse_attn,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
MXFP4_BLOCK_SIZE,
)
from vllm.v1.kv_cache_interface import (
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowMLASpec,
)
class CompressorBackend(AttentionBackend):
def __init__(self):
super().__init__()
@staticmethod
def get_name() -> str:
return "CompressorBackend"
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(1)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [512, 1024]
@staticmethod
def get_builder_cls() -> type["CompressorMetadataBuilder"]:
return CompressorMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
assert num_kv_heads == 1
return (num_blocks, block_size, head_size)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
if include_num_layers_dimension:
return (0, 1, 2, 3)
return (0, 1, 2)
@dataclass
class CompressorMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor
block_size: int
token_to_req_indices: torch.Tensor | None = None # [num_tokens]
class CompressorMetadataBuilder(AttentionMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec)
mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec)
self.block_size = mla_spec.block_size
self.token_to_req_indices = torch.zeros(
self.vllm_config.scheduler_config.max_num_batched_tokens,
dtype=torch.int32,
device=self.device,
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> CompressorMetadata:
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_reqs = common_attn_metadata.num_reqs
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
x = torch.repeat_interleave(torch.arange(num_reqs), query_lens).pin_memory()
token_to_req_indices = self.token_to_req_indices[: x.shape[0]]
token_to_req_indices.copy_(x, non_blocking=True)
return CompressorMetadata(
block_table=common_attn_metadata.block_table_tensor.clamp_(min=0),
slot_mapping=common_attn_metadata.slot_mapping,
block_size=self.block_size,
token_to_req_indices=token_to_req_indices,
)
class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
def __init__(
self,
state_dim: int,
dtype: torch.dtype,
compress_ratio: int,
prefix: str,
):
super().__init__()
self.state_dim = state_dim
self.dtype = dtype
self.prefix = prefix
self.kv_cache = torch.tensor([])
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
assert self.dtype == torch.float32
assert compress_ratio in [4, 128]
coff = 1 + (compress_ratio == 4)
self.sliding_window = coff * compress_ratio
# Block size is constrained by tensor sharing between compressor states
# and KV blocks. Since compressor states share the same physical tensor
# as KV blocks, they must use the same page size.
# The KV block shape [256//4, head_dim] = [64, 584] determines:
# - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4
# - C128 compressor block shape [8, 512*2*4] -> block_size = 8
# TODO(yifan): make block size automatically determined and configurable.
if compress_ratio == 4:
self.block_size = 4
elif compress_ratio == 128:
self.block_size = 8
else:
raise ValueError(f"Invalid compress ratio: {compress_ratio}")
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return SlidingWindowMLASpec( # only has one vector instead of K + V
block_size=self.block_size,
num_kv_heads=1,
head_size=self.state_dim,
dtype=self.dtype,
sliding_window=self.sliding_window,
alignment=576, # NOTE: FlashMLA requires 576B alignment
)
def forward(self): ...
def get_attn_backend(self) -> type[AttentionBackend]:
return CompressorBackend
class DeepseekCompressor(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
compress_ratio: int,
hidden_size: int,
head_dim: int,
rotate: bool = False,
prefix: str = "",
k_cache_prefix="",
use_fp4_cache: bool = False,
):
super().__init__()
self.compress_ratio = compress_ratio
self.hidden_size = hidden_size
self.head_dim = head_dim
self.rotate = rotate
self.prefix = prefix
self.k_cache_prefix = k_cache_prefix
self.use_fp4_cache = use_fp4_cache
config = vllm_config.model_config.hf_config
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = self.head_dim - self.rope_head_dim
self.rms_norm_eps = config.rms_norm_eps
self.device = current_platform.device_type
self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.max_model_len = vllm_config.model_config.max_model_len
self.overlap = compress_ratio == 4
self.coff = 1 + self.overlap
state_dtype = torch.float32
self.ape = nn.Parameter(
torch.empty(
(compress_ratio, self.coff * self.head_dim),
dtype=state_dtype,
device=self.device,
),
requires_grad=False,
)
self.fused_wkv_wgate = MergedColumnParallelLinear(
self.hidden_size,
[self.coff * self.head_dim, self.coff * self.head_dim],
bias=False,
return_bias=False,
quant_config=None,
disable_tp=True,
prefix=f"{prefix}.fused_wkv_wgate",
)
self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)
self.state_cache = CompressorStateCache(
state_dim=2 * self.coff * self.head_dim, # kv_state + score_state
dtype=state_dtype,
compress_ratio=compress_ratio,
prefix=f"{prefix}.state_cache",
)
# Save reference to static_forward_context for forward-time KV cache lookup.
# get_current_vllm_config() is only available during __init__, not forward.
self._static_forward_context = (
vllm_config.compilation_config.static_forward_context
)
if self.head_dim == 512:
assert not use_fp4_cache, (
"MXFP4 cache is only supported for indexer (head=128)"
)
self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
self._quant_block = 64
self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad
self._num_warps = 4
elif self.head_dim == 128:
if use_fp4_cache:
self._fused_kernel = (
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
)
self._quant_block = MXFP4_BLOCK_SIZE
self._token_stride = self.head_dim // 2
self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE
else:
self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
self._quant_block = 128
self._token_stride = self.head_dim
self._scale_dim = 4 # single float32 scale
self._num_warps = 1
else:
raise ValueError(
f"Unsupported head_dim for fused quant+cache: {self.head_dim}"
)
def forward(
self,
# [num_tokens, hidden_size]
x: torch.Tensor,
# [num_tokens]
positions: torch.Tensor,
rotary_emb,
) -> None:
num_tokens, _ = x.shape
# bf16 weights/activations but fp32 output for numerical stability of
# the downstream compressor math.
kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight)
# Each of shape [num_tokens, coff * self.head_dim]
# input bf16, output are fp32
kv, score = kv_score.split(
[self.coff * self.head_dim, self.coff * self.head_dim], dim=-1
)
# Get the metadata and handle dummy profiling run.
attn_metadata = get_forward_context().attn_metadata
if not isinstance(attn_metadata, dict):
return
state_metadata = cast(
CompressorMetadata, attn_metadata[self.state_cache.prefix]
)
token_to_req_indices = state_metadata.token_to_req_indices
slot_mapping = state_metadata.slot_mapping
num_actual = slot_mapping.shape[0]
block_table = state_metadata.block_table
block_size = state_metadata.block_size
# [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
state_cache = self.state_cache.kv_cache
# kv_state stored in first half, score_state stored in second half
state_width = state_cache.shape[-1] // 2
# Store the KV and score (with fused APE addition) in the state.
# NOTE: PDL is disabled — both this kernel and _fused_kernel below
# depend on preceding kernel outputs (kv/score from the cublas GEMM;
# state_cache from this kernel) but neither emits/waits on PDL grid
# dependency primitives, so launch_pdl=True caused a read-after-write
# race and non-deterministic output.
_save_partial_states_kernel[(num_actual,)](
kv,
kv.stride(0),
score,
score.stride(0),
self.ape,
self.ape.stride(0),
positions,
state_cache,
state_cache.stride(0),
state_cache.stride(1),
slot_mapping,
block_size,
HEAD_SIZE=kv.shape[-1],
TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]),
STATE_WIDTH=state_width,
COMPRESS_RATIO=self.compress_ratio,
launch_pdl=False,
)
# Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
# RoPE requirements (kernel applies forward GPT-J style rotation):
# - is_neox_style=False (interleaved pairs, NOT split-half)
# - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos,
# second half sin (per-pair, length rope_head_dim // 2 each)
# - applied to LAST rope_head_dim elements of head_dim
# - position used: (positions // compress_ratio) * compress_ratio
cos_sin_cache = rotary_emb.cos_sin_cache
k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache
self._fused_kernel[(num_actual,)](
# state cache
state_cache,
state_cache.stride(0),
state_cache.stride(1),
# metadata
token_to_req_indices,
positions,
slot_mapping,
block_table,
block_table.stride(0),
block_size,
# RMSNorm
self.norm.weight,
self.rms_norm_eps,
# RoPE
cos_sin_cache,
cos_sin_cache.stride(0),
# KV cache
kv_cache,
k_cache_metadata.slot_mapping,
kv_cache.shape[1], # paged KV cache block size (tokens per block)
# constexprs
HEAD_SIZE=self.head_dim,
TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
STATE_WIDTH=state_width,
COMPRESS_RATIO=self.compress_ratio,
OVERLAP=self.overlap,
ROPE_HEAD_DIM=self.rope_head_dim,
FP8_MAX=448.0,
QUANT_BLOCK=self._quant_block,
TOKEN_STRIDE=self._token_stride,
SCALE_DIM=self._scale_dim,
KV_BLOCK_STRIDE=kv_cache.stride(0),
num_warps=self._num_warps,
launch_pdl=False,
)
@triton.jit
def _save_partial_states_kernel(
kv_ptr,
kv_stride,
score_ptr,
score_stride,
ape_ptr,
ape_stride,
positions_ptr,
state_cache_ptr,
state_cache_stride0,
state_cache_stride1,
slot_mapping_ptr,
block_size,
HEAD_SIZE: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
# state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
STATE_WIDTH: tl.constexpr,
COMPRESS_RATIO: tl.constexpr,
):
token_idx = tl.program_id(0)
slot_id = tl.load(slot_mapping_ptr + token_idx)
# Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used
# by vLLM). During CUDA graph replay the batch may contain padding
# tokens whose slot_mapping is -1; writing to kv_state[-1] would be an
# illegal memory access.
if slot_id < 0:
return
block_idx = slot_id // block_size
pos_in_block = slot_id % block_size
base_ptr = (
state_cache_ptr
+ block_idx * state_cache_stride0
+ pos_in_block * state_cache_stride1
)
block = tl.arange(0, TRITON_BLOCK_SIZE)
mask = block < HEAD_SIZE
kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask)
tl.store(base_ptr + block, kv, mask=mask)
# Fused: score += ape[position % compress_ratio]
position = tl.load(positions_ptr + token_idx)
ape_row = position % COMPRESS_RATIO
ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask)
score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask)
tl.store(
base_ptr + STATE_WIDTH + block,
score + ape,
mask=mask,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DeepseekV4 MLA Attention Layer
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm.model_executor.layers.linear import (
ReplicatedLinear,
)
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.utils.deep_gemm import fp8_einsum
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.ops.deepseek_v4_ops import (
combine_topk_swa_indices,
compute_global_topk_indices_and_lens,
dequantize_and_gather_k_cache,
fused_indexer_q_rope_quant,
fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
)
if TYPE_CHECKING:
from vllm.v1.attention.backends.mla.sparse_swa import (
DeepseekSparseSWAMetadata,
)
from vllm.config import (
CacheConfig,
VllmConfig,
get_current_vllm_config,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.mla.flashmla_sparse import (
DeepseekV4FlashMLASparseBackend,
FlashMLASparseBackend,
FlashMLASparseMetadata,
)
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV4IndexerBackend,
get_max_prefill_buffer_size,
)
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache
from vllm.v1.attention.ops.flashmla import (
flash_mla_sparse_fwd,
flash_mla_with_kvcache,
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
logger = init_logger(__name__)
# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
# workspace allocated at _forward_prefill (and the matching profile-time
# reservation in attention_impl's dummy-run branch).
PREFILL_CHUNK_SIZE = 4
@dataclass
class DeepseekV4MLAModules:
"""Modules used in DeepseekV4 MLA."""
vllm_config: VllmConfig
fused_wqa_wkv: torch.nn.Module
q_norm: torch.nn.Module
wq_b: torch.nn.Module
kv_norm: torch.nn.Module
wo_a: torch.nn.Module
wo_b: torch.nn.Module
attn_sink: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module | None
indexer_rotary_emb: torch.nn.Module
topk_indices_buffer: torch.Tensor | None
aux_stream: torch.cuda.Stream | None = None
# --8<-- [start:multi_head_latent_attention]
@PluggableLayer.register("deepseek_v4_multi_head_latent_attention")
class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
"""Pluggable MLA layer which allows OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently oot platforms can still use CustomOp.register_oot to
replace MLA layer entirely, although we use PluggableLayer to register
this layer now.
This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
The class does the following:
1. MLA Preprocess.
2. Perform multi-head attention to prefill tokens and
multi-query attention to decode tokens separately.
3. Return the output tensor.
"""
# --8<-- [end:multi_head_latent_attention]
def __init__(
self,
hidden_size: int,
num_heads: int,
head_dim: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
o_lora_rank: int | None,
mla_modules: DeepseekV4MLAModules,
window_size: int,
compress_ratio: int | None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.n_local_heads = num_heads
self.head_dim = head_dim
self.scale = scale
# FlashMLA sparse kernel only supports 64 or 128 heads; pad up to the
# next supported size. Must match DeepseekV4MLAAttention.padded_heads.
if num_heads <= 64:
self.padded_heads = 64
elif num_heads <= 128:
self.padded_heads = 128
else:
raise ValueError(
f"DeepseekV4 attention does not support {num_heads} heads "
"(must be <= 128)."
)
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.window_size = window_size
self.compress_ratio = compress_ratio if compress_ratio is not None else 1
self.prefix = prefix
# Extract config from vllm_config
config = mla_modules.vllm_config.model_config.hf_config
tp_size = get_tensor_model_parallel_world_size()
# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
self.eps = config.rms_norm_eps
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = head_dim - self.rope_head_dim
self.n_local_groups = config.o_groups // tp_size
self.o_lora_rank = config.o_lora_rank
# Store projection modules
self.fused_wqa_wkv = mla_modules.fused_wqa_wkv
self.q_norm = mla_modules.q_norm
self.wq_b = mla_modules.wq_b
self.kv_norm = mla_modules.kv_norm
self.wo_a = mla_modules.wo_a
self._wo_a_act_quant = QuantFP8(
static=False,
group_shape=GroupShape(1, 128),
use_ue8m0=True,
)
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
# INT32) so fp8_einsum can handle layout transform internally.
self._wo_a_act_quant.use_deep_gemm_supported = False
self.wo_b = mla_modules.wo_b
# Pick fp8_einsum recipe based on GPU arch:
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
# SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1
from vllm.platforms import current_platform
cap = current_platform.get_device_capability()
assert cap is not None, "DeepseekV4 attention requires a CUDA device"
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
self._tma_aligned_scales = cap.major >= 10
self.rotary_emb = mla_modules.rotary_emb
self.indexer_rotary_emb = mla_modules.indexer_rotary_emb
self.topk_indices_buffer = mla_modules.topk_indices_buffer
self.indexer = mla_modules.indexer
# Per-head RMS normalization for Q (no learnable weights)
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
# TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
head_bytes = (
self.nope_head_dim # 448 fp8 NoPE
+ self.rope_head_dim * 2 # 64 bf16 RoPE
+ self.nope_head_dim // 64 # 7B scale factors
+ 1 # 1B pad
)
self.aux_stream = mla_modules.aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
assert cache_config is not None, "DeepseekV4 attention requires cache_config"
self.swa_cache_layer = DeepseekV4SWACache(
head_dim=self.head_dim,
window_size=self.window_size,
dtype=torch.uint8,
prefix=f"{prefix}.swa_cache",
cache_config=cache_config,
)
self.mla_attn = DeepseekV4MLAAttention(
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
compress_ratio=self.compress_ratio,
window_size=self.window_size,
head_bytes=head_bytes,
swa_cache_layer=self.swa_cache_layer,
attn_sink=mla_modules.attn_sink, # already padded with -inf
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
indexer=self.indexer,
topk_indices_buffer=self.topk_indices_buffer,
)
# Register this layer in the compilation config's static forward context
# This allows the custom op to retrieve the layer during execution
compilation_config = mla_modules.vllm_config.compilation_config
# HACK
self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention"
if self.layer_name in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {self.layer_name}")
compilation_config.static_forward_context[self.layer_name] = self
# Create the compressor for layers with compress_ratio > 1; after
# creating the DeepseekV4MLAAttention layer to get its cache.
self.compressor = None
if self.compress_ratio > 1:
self.compressor = DeepseekCompressor(
vllm_config=mla_modules.vllm_config,
compress_ratio=self.compress_ratio,
hidden_size=self.hidden_size,
head_dim=self.head_dim,
rotate=True,
prefix=f"{prefix}.compressor",
k_cache_prefix=self.mla_attn.prefix,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
qr_kv, _ = self.fused_wqa_wkv(hidden_states)
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
# Pre-allocate attention output with FlashMLA-padded head count.
# The op writes into `o_padded`; we slice to n_local_heads after.
num_tokens = hidden_states.shape[0]
o_padded = torch.empty(
(num_tokens, self.padded_heads, self.head_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
# Attention (inside custom op for torch.compile boundary)
torch.ops.vllm.deepseek_v4_attention(
hidden_states,
qr,
kv,
positions,
o_padded,
self.layer_name,
)
o = o_padded[:, : self.n_local_heads, :]
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
o_fp8, o_scale = fused_inv_rope_fp8_quant(
o,
positions,
self.rotary_emb.cos_sin_cache,
n_groups=self.n_local_groups,
heads_per_group=self.n_local_heads // self.n_local_groups,
nope_dim=self.nope_head_dim,
rope_dim=self.rope_head_dim,
tma_aligned_scales=self._tma_aligned_scales,
)
wo_a_fp8 = self.wo_a.weight
wo_a_scale = self.wo_a.weight_scale_inv
z = torch.empty(
(num_tokens, self.n_local_groups, self.o_lora_rank),
device=o.device,
dtype=torch.bfloat16,
)
torch.ops.vllm.deepseek_v4_fp8_einsum(
o_fp8,
o_scale,
wo_a_fp8,
wo_a_scale,
z,
"bhr,hdr->bhd",
list(self._einsum_recipe),
)
return self.wo_b(z.flatten(1))
def attention_impl(
self,
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place
) -> None:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
qr, kv = fused_q_kv_rmsnorm(
qr,
kv,
self.q_norm.weight.data,
self.kv_norm.weight.data,
self.eps,
)
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
# Overlap kv_insert with whichever of indexer/compressor is present.
# Indexer implies compressor; when both exist, compressor rides on the
# aux stream alongside kv_insert so the heavy indexer owns default.
if self.indexer is not None:
indexer = self.indexer
# Local ref so the closure keeps a non-None type for mypy.
assert self.compressor is not None
compressor = self.compressor
def kv_insert_and_compress() -> None:
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
compressor(hidden_states, positions, self.rotary_emb)
maybe_execute_in_parallel(
lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb),
kv_insert_and_compress,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
elif self.compressor is not None:
# Compressor on default, kv_insert on aux.
compressor = self.compressor
maybe_execute_in_parallel(
lambda: compressor(hidden_states, positions, self.rotary_emb),
lambda: self._fused_qnorm_rope_kv_insert(
q, kv, positions, attn_metadata
),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
else:
# SWA-only layer: no compressor, no overlap.
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
# Handle dummy run (no metadata).
if not isinstance(attn_metadata, dict):
# Reserve _forward_prefill's bf16-gather workspace; the dummy
# run returns before mla_attn runs, so without this the shared
# workspace locks below the real prefill size.
sub = self.mla_attn
swa_only = sub.compress_ratio <= 1
N = (
0
if swa_only
else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio
)
M = N + sub.window_size + sub.max_num_batched_tokens
current_workspace_manager().get_simultaneous(
((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)
out.zero_()
return
# Pad q to FlashMLA-required head count (64 or 128)
if self.n_local_heads < self.padded_heads:
pad_size = self.padded_heads - self.n_local_heads
q = F.pad(q, (0, 0, 0, pad_size), value=0.0)
# MLA attention writes into the pre-allocated `out` buffer
# ([num_tokens, padded_heads, head_dim]).
self.mla_attn(q, kv, positions, output=out)
def _fused_qnorm_rope_kv_insert(
self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
attn_metadata: (
dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None
),
) -> None:
if not isinstance(attn_metadata, dict):
return
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
swa_kv_cache = self.swa_cache_layer.kv_cache
swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1)
# Horizontally fused:
# Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE
# KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert
# kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
q,
kv,
swa_kv_cache_2d,
swa_metadata.slot_mapping,
positions.to(torch.int64),
self.rotary_emb.cos_sin_cache,
self.eps,
swa_metadata.block_size,
)
def deepseek_v4_attention(
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.attention_impl(hidden_states, qr, kv, positions, out)
def deepseek_v4_attention_fake(
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor,
layer_name: str,
) -> None:
return None
direct_register_custom_op(
op_name="deepseek_v4_attention",
op_func=deepseek_v4_attention,
mutates_args=["out"],
fake_impl=deepseek_v4_attention_fake,
)
def deepseek_v4_fp8_einsum(
a: torch.Tensor,
a_scale: torch.Tensor,
b: torch.Tensor,
b_scale: torch.Tensor,
out: torch.Tensor,
equation: str,
recipe: list[int],
) -> None:
fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe))
def deepseek_v4_fp8_einsum_fake(
a: torch.Tensor,
a_scale: torch.Tensor,
b: torch.Tensor,
b_scale: torch.Tensor,
out: torch.Tensor,
equation: str,
recipe: list[int],
) -> None:
return None
direct_register_custom_op(
op_name="deepseek_v4_fp8_einsum",
op_func=deepseek_v4_fp8_einsum,
mutates_args=["out"],
fake_impl=deepseek_v4_fp8_einsum_fake,
)
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
# FlashMLA FP8 sparse only supports 64 or 128 heads
SUPPORTED_HEAD_COUNTS = (64, 128)
def __init__(
self,
num_heads: int,
head_dim: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
compress_ratio: int,
window_size: int,
head_bytes: int,
swa_cache_layer: DeepseekV4SWACache,
attn_sink: torch.Tensor,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
# Sparse MLA Args
indexer: object | None = None,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream: torch.cuda.Stream | None = None,
**extra_impl_args,
) -> None:
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = 1
self.head_dim = head_dim
self.scale = scale
self.window_size = window_size
self.head_bytes = head_bytes
self.compress_ratio = compress_ratio
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.nope_head_dim = qk_nope_head_dim
self.rope_head_dim = qk_rope_head_dim
self.indexer = indexer
self.topk_indices_buffer = topk_indices_buffer
self.prefix = prefix # Alias for compatibility with compressor
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
# Determine padded head count for FlashMLA
if num_heads not in self.SUPPORTED_HEAD_COUNTS:
if num_heads < 64:
self.padded_heads = 64
elif num_heads < 128:
self.padded_heads = 128
else:
raise ValueError(
f"DeepseekV4MLAAttention does not support {num_heads} heads. "
f"Supported: <= 128 (will be padded to 64 or 128)"
)
else:
self.padded_heads = num_heads
# Store attention sink
assert attn_sink is not None
self.attn_sink: torch.Tensor = attn_sink
# Store SWA cache
assert swa_cache_layer is not None
self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer
# Get vllm config for cache setup
vllm_config = get_current_vllm_config()
self.max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens
)
self.max_model_len = vllm_config.model_config.max_model_len
# DeepseekV4 only supports fp8 kv-cache format for now
kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
assert kv_cache_dtype.startswith("fp8"), (
f"DeepseekV4 only supports fp8 kv-cache format for now, "
f"got {kv_cache_dtype}"
)
assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
"Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
)
# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if (
issubclass(self.get_attn_backend(), FlashMLASparseBackend)
and kv_cache_dtype.startswith("fp8")
and kv_cache_dtype != "fp8_ds_mla"
):
assert cache_config is not None
cache_config.cache_dtype = "fp8_ds_mla"
kv_cache_dtype = "fp8_ds_mla"
logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
self.kv_cache_dtype = kv_cache_dtype
# Register with compilation context for metadata lookup
compilation_config = vllm_config.compilation_config
if prefix and prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
if prefix:
compilation_config.static_forward_context[prefix] = self
self.kv_cache = torch.tensor([])
def get_attn_backend(self) -> type[AttentionBackend]:
return DeepseekV4FlashMLASparseBackend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
if (
self.compress_ratio <= 1
): # SWA part. Allocated separately as DeepseekV4SWACache.
return None
return MLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=torch.uint8,
compress_ratio=self.compress_ratio,
cache_dtype_str=self.kv_cache_dtype,
alignment=576, # NOTE: FlashMLA requires 576B alignment
model_version="deepseek_v4",
)
def forward(
self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
output: torch.Tensor,
) -> None:
assert output.shape == q.shape, (
f"output buffer shape {output.shape} must match q shape {q.shape}"
)
assert output.dtype == q.dtype, (
f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
)
# Get SWA and indexer metadata from forward context
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
assert isinstance(attn_metadata, dict)
flashmla_metadata = cast(
FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
)
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
swa_only = self.compress_ratio <= 1
# SWA-only layers (compress_ratio <= 1) don't have their own KV cache
# allocation, so self.kv_cache may be empty after profiling cleanup.
self_kv_cache = self.kv_cache if not swa_only else None
swa_kv_cache = self.swa_cache_layer.kv_cache
# Split prefill and decode
num_decodes = swa_metadata.num_decodes
num_prefills = swa_metadata.num_prefills
num_decode_tokens = swa_metadata.num_decode_tokens
if num_prefills > 0:
self._forward_prefill(
q=q[num_decode_tokens:],
positions=positions[num_decode_tokens:],
compressed_k_cache=self_kv_cache,
swa_k_cache=swa_kv_cache,
output=output[num_decode_tokens:],
attn_metadata=flashmla_metadata,
swa_metadata=swa_metadata,
)
if num_decodes > 0:
self._forward_decode(
q=q[:num_decode_tokens],
kv_cache=self_kv_cache,
swa_metadata=swa_metadata,
attn_metadata=flashmla_metadata,
swa_only=swa_only,
output=output[:num_decode_tokens],
)
def _forward_decode(
self,
q: torch.Tensor,
kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1
swa_metadata: "DeepseekSparseSWAMetadata",
attn_metadata: FlashMLASparseMetadata | None,
swa_only: bool,
output: torch.Tensor,
) -> None:
num_decodes = swa_metadata.num_decodes
num_decode_tokens = swa_metadata.num_decode_tokens
topk_indices = None
topk_lens = None
if not swa_only:
assert attn_metadata is not None
assert swa_metadata.is_valid_token is not None
block_size = attn_metadata.block_size // self.compress_ratio
is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
if self.compress_ratio == 4:
# C4A: local indices differ per layer (filled by Indexer).
assert self.topk_indices_buffer is not None
global_indices, topk_lens = compute_global_topk_indices_and_lens(
self.topk_indices_buffer[:num_decode_tokens],
swa_metadata.token_to_req_indices,
attn_metadata.block_table[:num_decodes],
block_size,
is_valid,
)
topk_indices = global_indices.view(num_decode_tokens, 1, -1)
else:
# C128A: pre-computed during metadata build.
topk_indices = attn_metadata.c128a_global_decode_topk_indices
topk_lens = attn_metadata.c128a_decode_topk_lens
swa_indices = swa_metadata.decode_swa_indices
swa_lens = swa_metadata.decode_swa_lens
# We treat queries in the same seq as different queries
# and later we only attend by generated indices.
# q arrives pre-padded to self.padded_heads by the outer wrapper.
q = q.unsqueeze(1)
# Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes)
# Use unsqueeze to preserve strides (handles padded blocks correctly)
swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2)
# Reshape KV cache to (num_blocks, block_size, 1, head_bytes)
if kv_cache is not None:
kv_cache = kv_cache.unsqueeze(-2)
# One FlashMLASchedMeta per layer type, shared across all same-type
# layers within this decode step. The first forward call per type
# triggers the in-kernel planner (allocating tile_scheduler_metadata
# and num_splits via PyTorch's graph-aware allocator so CUDA graph
# capture reuses the same addresses on replay); subsequent same-type
# layers see have_initialized=True and skip the planner.
if self.compress_ratio <= 1:
tile_metadata = swa_metadata.tile_sched_swaonly
elif self.compress_ratio == 4:
tile_metadata = swa_metadata.tile_sched_c4a
elif self.compress_ratio == 128:
tile_metadata = swa_metadata.tile_sched_c128a
else:
raise ValueError(
f"Unsupported compress_ratio={self.compress_ratio}; "
"expected 1, 4, or 128."
)
assert tile_metadata is not None, (
"swa_metadata missing tile_sched entry for "
f"compress_ratio={self.compress_ratio}; "
"DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not "
"allocate one for this layer type."
)
out, _ = flash_mla_with_kvcache(
q=q,
k_cache=swa_cache,
block_table=None,
head_dim_v=512,
tile_scheduler_metadata=tile_metadata,
cache_seqlens=None,
is_fp8_kvcache=True,
indices=swa_indices,
topk_length=swa_lens,
softmax_scale=self.scale,
attn_sink=self.attn_sink,
extra_k_cache=kv_cache if not swa_only else None,
extra_indices_in_kvcache=topk_indices,
extra_topk_length=topk_lens,
out=output.unsqueeze(1),
)
def _forward_prefill(
self,
q: torch.Tensor,
positions: torch.Tensor,
compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1
swa_k_cache: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashMLASparseMetadata | None,
swa_metadata: "DeepseekSparseSWAMetadata",
) -> None:
swa_only = attn_metadata is None
num_prefills = swa_metadata.num_prefills
num_prefill_tokens = swa_metadata.num_prefill_tokens
num_decodes = swa_metadata.num_decodes
num_decode_tokens = swa_metadata.num_decode_tokens
# Use pre-computed prefill metadata.
seq_lens = swa_metadata.prefill_seq_lens
gather_lens = swa_metadata.prefill_gather_lens
assert seq_lens is not None
assert gather_lens is not None
# Derive prefill-local token offsets from the full query_start_loc_cpu.
query_start_loc_cpu = swa_metadata.query_start_loc_cpu
query_start_loc = swa_metadata.query_start_loc
assert query_start_loc_cpu is not None
assert query_start_loc is not None
prefill_token_base = query_start_loc_cpu[num_decodes]
if not swa_only:
if self.compress_ratio == 4:
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
topk_indices = topk_indices[:num_prefill_tokens]
else:
# C128A: pre-computed during metadata build.
assert attn_metadata is not None
topk_indices = attn_metadata.c128a_prefill_topk_indices
top_k = topk_indices.shape[-1]
# Compressed region must fit the full compressed pool (seq_len //
# compress_ratio), not just top_k. top_k bounds how many indices
# the indexer selects, not the pool size it indexes into.
N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
else:
# NOTE(woosuk): topk_indices will not be used for SWA-only layers.
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
top_k = 0
N = 0
M = N + self.window_size + self.max_num_batched_tokens
num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE
workspace_manager = current_workspace_manager()
kv = workspace_manager.get_simultaneous(
((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)[0]
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * PREFILL_CHUNK_SIZE
chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills)
chunk_size = chunk_end - chunk_start
if not swa_only:
# Gather compressed KV
assert attn_metadata is not None
block_table = attn_metadata.block_table[num_decodes:]
dequantize_and_gather_k_cache(
kv[:chunk_size],
compressed_k_cache,
seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
gather_lens=None,
block_table=block_table[chunk_start:chunk_end],
block_size=attn_metadata.block_size // self.compress_ratio,
offset=0,
)
# Gather SWA KV
swa_block_table = swa_metadata.block_table[num_decodes:]
dequantize_and_gather_k_cache(
kv[:chunk_size],
swa_k_cache,
seq_lens=seq_lens[chunk_start:chunk_end],
gather_lens=gather_lens[chunk_start:chunk_end],
block_table=swa_block_table[chunk_start:chunk_end],
block_size=swa_metadata.block_size,
offset=N,
)
# Combine the topk indices and SWA indices for gathered KV cache
query_start = (
query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base
)
query_end = (
query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base
)
combined_indices, combined_lens = combine_topk_swa_indices(
topk_indices[query_start:query_end],
query_start_loc[
num_decodes + chunk_start : num_decodes + chunk_end + 1
],
seq_lens[chunk_start:chunk_end],
gather_lens[chunk_start:chunk_end],
self.window_size,
self.compress_ratio,
top_k,
M,
N,
)
output_chunk, _, _ = flash_mla_sparse_fwd(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
sm_scale=self.scale,
attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
def __init__(
self,
head_dim: int,
dtype: torch.dtype,
prefix: str,
cache_config: CacheConfig,
compress_ratio: int = 1,
):
super().__init__()
self.kv_cache = torch.tensor([])
self.head_dim = head_dim
self.prefix = prefix
self.cache_config = cache_config
self.dtype = dtype
self.compress_ratio = compress_ratio
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# head_dim already carries the fp8 scale padding
# compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout.
return MLAAttentionSpec(
block_size=self.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
compress_ratio=self.compress_ratio,
# DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with
# the indexer's compressor state cache. V3.2 keeps the legacy layout.
alignment=576,
)
def forward(self): ...
def get_attn_backend(self) -> type[AttentionBackend]:
return DeepseekV4IndexerBackend
class DeepseekV4Indexer(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
config: DeepseekV2Config | DeepseekV3Config,
hidden_size: int,
q_lora_rank: int,
quant_config: QuantizationConfig | None,
cache_config: CacheConfig | None,
topk_indices_buffer: torch.Tensor | None,
compress_ratio: int = 1,
prefix: str = "",
):
super().__init__()
self.vllm_config = vllm_config
self.config = config
self.quant_config = quant_config
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self.topk_tokens = config.index_topk
self.n_head = config.index_n_heads # 64
self.head_dim = config.index_head_dim # 128
self.rope_dim = config.qk_rope_head_dim # 64
self.q_lora_rank = q_lora_rank # 1536
self.compress_ratio = compress_ratio
self.use_fp4_kv = self.vllm_config.attention_config.use_fp4_indexer_cache
logger.info_once(
"Using %s indexer cache for Lighening Indexer.",
"MXFP4" if self.use_fp4_kv else "FP8",
)
# no tensor parallel, just replicated
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.head_dim * self.n_head,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = "ue8m0"
self.quant_block_size = 128 # TODO: get from config
self.topk_indices_buffer = topk_indices_buffer
self.max_model_len = (
vllm_config.model_config.max_model_len // self.compress_ratio
)
self.prefix = prefix
self.max_total_seq_len = (
get_max_prefill_buffer_size(vllm_config) // self.compress_ratio
)
assert cache_config is not None, "Deepseek V4 indexer requires cache_config"
# NOTE(yifan): FP8 indxer cache use the same layout as V3.2:
# head_dim bytes = 128 fp8 + 4 fp32 scale = 132.
# For FP4 indexer cache, we still allocate the same amount of memory as FP8,
# but only use the first half of the memory.
k_cache_head_dim = self.head_dim + self.head_dim // self.quant_block_size * 4
self.k_cache = DeepseekV4IndexerCache(
head_dim=k_cache_head_dim,
dtype=torch.uint8,
prefix=f"{prefix}.k_cache",
cache_config=cache_config,
compress_ratio=self.compress_ratio,
)
self.compressor = DeepseekCompressor(
vllm_config=vllm_config,
compress_ratio=self.compress_ratio,
hidden_size=hidden_size,
head_dim=self.head_dim,
rotate=True,
prefix=f"{prefix}.compressor",
k_cache_prefix=self.k_cache.prefix,
use_fp4_cache=self.use_fp4_kv,
)
self.indexer_op = SparseAttnIndexer(
self.k_cache,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
skip_k_cache_insert=True,
use_fp4_cache=self.use_fp4_kv,
)
def forward(
self,
hidden_states: torch.Tensor,
qr: torch.Tensor,
positions: torch.Tensor,
rotary_emb: nn.Module,
) -> torch.Tensor:
q, _ = self.wq_b(qr)
q = q.view(-1, self.n_head, self.head_dim)
k = self.compressor(hidden_states, positions, rotary_emb)
weights, _ = self.weights_proj(hidden_states)
q_quant, weights = fused_indexer_q_rope_quant(
positions,
q,
rotary_emb.cos_sin_cache,
weights,
self.softmax_scale,
self.n_head**-0.5,
use_fp4=self.use_fp4_kv,
)
return self.indexer_op(hidden_states, q_quant, k, weights)
......@@ -117,8 +117,10 @@ class RoutingMethodType(IntEnum):
Custom = (6,)
# Simulated
Simulated = (7,)
# Deepseek V4 -> sqrtsoftplus + Bias + Normalize
DeepseekV4 = (8,)
# Unspecified
Unspecified = 8.0
Unspecified = 9.0
def get_routing_method_type(
......@@ -128,6 +130,14 @@ def get_routing_method_type(
num_expert_group: int | None,
has_e_score_bias: bool,
) -> RoutingMethodType:
if scoring_func == "sqrtsoftplus":
# DeepSeek V4 uses sqrtsoftplus routing with optional routing bias
# and top-k renormalization.
if renormalize:
return RoutingMethodType.DeepseekV4
else:
return RoutingMethodType.Unspecified
if has_e_score_bias:
if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
return RoutingMethodType.DeepSeekV3
......@@ -230,6 +240,13 @@ class FusedMoEQuantConfig:
_w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True
# MXFP4-specific TRTLLM parameters for SwiGLU activation clamping.
# These correspond to gemm1_alpha, gemm1_beta, gemm1_clamp_limit
# in TrtLlmMxfp4ExpertsBase.
gemm1_alpha: float | None = None
gemm1_beta: float | None = None
gemm1_clamp_limit: float | None = None
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization"
......@@ -477,6 +494,9 @@ class FusedMoEQuantConfig:
w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None,
is_nvfp4_scale_swizzled: bool = True,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
......@@ -507,6 +527,9 @@ class FusedMoEQuantConfig:
- w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization.
- is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling.
- gemm1_alpha: Optional MXFP4 TRTLLM SwiGLU alpha parameter.
- gemm1_beta: Optional MXFP4 TRTLLM SwiGLU beta parameter.
- gemm1_clamp_limit: Optional MXFP4 TRTLLM SwiGLU clamp limit.
"""
assert not isinstance(quant_dtype, str) or quant_dtype in {
"nvfp4",
......@@ -540,6 +563,9 @@ class FusedMoEQuantConfig:
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
......@@ -650,6 +676,9 @@ def mxfp4_w4a16_moe_quant_config(
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations and mxfp4 weights.
......@@ -659,6 +688,9 @@ def mxfp4_w4a16_moe_quant_config(
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
)
......@@ -670,6 +702,9 @@ def mxfp4_mxfp8_moe_quant_config(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for mxfp4 activations and mxfp4 weights.
......@@ -679,6 +714,9 @@ def mxfp4_mxfp8_moe_quant_config(
_a2=FusedMoEQuantDesc("mxfp8"),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
)
......@@ -712,6 +750,9 @@ def ocp_mx_moe_quant_config(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for mxfp4 activations and mxfp4 weights.
......@@ -729,6 +770,9 @@ def ocp_mx_moe_quant_config(
per_act_token_quant=False,
per_out_ch_quant=False,
block_shape=block_shape,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
)
......
......@@ -25,15 +25,20 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8_packed_for_deepgemm,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
silu_mul_quant_fp8_packed_triton as fused_silu_mul_fp8_quant_packed,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kMxfp4Static,
)
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_supported,
m_grouped_fp8_fp4_gemm_nt_contiguous,
m_grouped_fp8_gemm_nt_contiguous,
)
from vllm.utils.import_utils import has_deep_gemm
......@@ -197,8 +202,14 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation)
# 1. DeepGemm UE8M0: use packed per-token-group quant
# 1. DeepGemm UE8M0: fused SiLU+mul+clamp+quant+pack
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
if activation == MoEActivation.SILU:
return fused_silu_mul_fp8_quant_packed(
input=input,
output_q=output,
group_size=block_k,
)
act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
......@@ -312,3 +323,225 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
expert_map=expert_map,
output=output,
)
class DeepGemmFP4Experts(mk.FusedMoEExpertsModular):
"""DeepGemm-based fused MoE expert implementation for FP4 weights.
Uses m_grouped_fp8_fp4_gemm_nt_contiguous with FP8 activations and
MXFP4 (FP4 E2M1 packed as uint8) weights. Requires SM100+ (Blackwell).
"""
# FP8 activation block size (hardcoded since mxfp4_w4a8 quant config
# does not set a block_shape on the activation descriptor).
_ACT_BLOCK_K = 128
# FP4 weight block size
_WEIGHT_BLOCK_K = 32
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(moe_config=moe_config, quant_config=quant_config)
assert quant_config.weight_quant_dtype == "mxfp4"
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
from vllm.platforms import current_platform
return (
is_deep_gemm_supported()
and current_platform.is_device_capability_family(100)
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kMxfp4Static, kFp8Dynamic128Sym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
block_m = get_mk_alignment_for_contiguous_layout()[0]
M_sum = compute_aligned_M(
M, topk, local_num_experts, block_m, expert_tokens_meta
)
assert M_sum % block_m == 0
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M_sum, max(activation_out_dim, K))
workspace2 = (M_sum, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
def _act_mul_quant(
self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
) -> tuple[torch.Tensor, torch.Tensor]:
block_k = self._ACT_BLOCK_K
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation)
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
assert activation == MoEActivation.SILU
return fused_silu_mul_fp8_quant_packed(
input=input,
output_q=output,
group_size=block_k,
clamp_limit=self.gemm1_clamp_limit,
)
if activation == MoEActivation.SILU:
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
return silu_mul_per_token_group_quant_fp8_colmajor(
input=input,
output=output,
use_ue8m0=use_ue8m0,
)
act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
return per_token_group_quant_fp8(
act_out, block_k, column_major_scales=True, out_q=output
)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert a1q_scale is not None
assert a2_scale is None
assert self.w1_scale is not None
assert self.w2_scale is not None
a1q = hidden_states
_, N, _ = w1.size()
# K comes from activations (full hidden dim), not from w1 which is
# packed FP4 (E, N, K//2).
K = a1q.size(1)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
M_sum = compute_aligned_M(
M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=get_mk_alignment_for_contiguous_layout()[0],
expert_tokens_meta=expert_tokens_meta,
)
a1q_perm = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
)
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1q_scale,
topk_ids=topk_ids,
local_num_experts=local_num_experts,
expert_map=expert_map,
expert_tokens_meta=expert_tokens_meta,
aq_out=a1q_perm,
)
assert a1q.size(0) == M_sum
# FC1: FP8 activations x FP4 weights
# DeepGEMM 2.4.2 requires FP4-packed weights as int8 (kPackedFP4).
mm1_out = _resize_cache(workspace2, (M_sum, N))
m_grouped_fp8_fp4_gemm_nt_contiguous(
(a1q, a1q_scale),
(w1.view(torch.int8), self.w1_scale),
mm1_out,
expert_ids,
recipe_a=(1, self._ACT_BLOCK_K),
recipe_b=(1, self._WEIGHT_BLOCK_K),
)
# SwiGLU activation + FP8 requant
activation_out_dim = self.adjust_N_for_activation(N, activation)
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
)
a2q, a2q_scale = self._act_mul_quant(
input=mm1_out.view(-1, N), output=quant_out, activation=activation
)
# FC2: FP8 activations x FP4 weights
mm2_out = _resize_cache(workspace2, (M_sum, K))
m_grouped_fp8_fp4_gemm_nt_contiguous(
(a2q, a2q_scale),
(w2.view(torch.int8), self.w2_scale),
mm2_out,
expert_ids,
recipe_a=(1, self._ACT_BLOCK_K),
recipe_b=(1, self._WEIGHT_BLOCK_K),
)
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
deepgemm_unpermute_and_reduce(
a=mm2_out,
topk_ids=topk_ids,
topk_weights=topk_weights,
inv_perm=inv_perm,
expert_map=expert_map,
output=output,
)
......@@ -28,8 +28,162 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
from ..utils import swiglu_limit_func
logger = init_logger(__name__)
def _patch_make_bitmatrix_metadata() -> None:
"""Monkey-patch make_bitmatrix_metadata to support non-power-of-2 top_k.
triton's tl.arange requires a power-of-2 range. The original kernel
computes BLOCK_SIZE = BLOCK_PER_TOK * TOKS_PER_ROW (= 32 * top_k). For
DeepSeek-V4 with top_k=6 this gives 192, which is not a power of 2 and
causes a compile error at the first forward pass.
Fix: define a drop-in replacement kernel that accepts an extra constexpr
BLOCK_SIZE_PADDED (next power of 2 >= BLOCK_SIZE) and uses it for the
tl.arange call while keeping the actual BLOCK_SIZE as the stride between
thread-blocks so that all flat indices into NonzeroIndx stay correct.
Elements beyond BLOCK_SIZE are masked out (col_indx = 0xffff) and ignored.
This function is called once at module load time and patches the function
inside the triton_kernels tensor module so that SparseMatrix.__post_init__
picks up the fixed version transparently.
"""
import torch
import triton
import triton.language as tl
try:
from vllm.third_party.triton_kernels.tensor_details import (
bitmatrix as _bm,
)
from vllm.third_party.triton_kernels.tensor_details.bitmatrix import (
BitmatrixMetadata,
_keyed_add,
cdiv,
)
from vllm.third_party.triton_kernels.tensor_details.bitmatrix_details.sum_bitmatrix_rows import ( # noqa: E501
sum_bitmatrix_rows,
)
except ImportError:
return
@triton.jit
def _stage2_pow2(
ColSortedIndx,
RowSortedIndx,
NonzeroIndx,
n_tokens,
ColPartialSum,
stride_pm,
stride_pn,
ColOffs,
TOKS_PER_ROW: tl.constexpr,
BLOCK_PER_TOK: tl.constexpr,
BLOCK_SIZE_PADDED: tl.constexpr,
):
# Actual number of elements per block (may not be a power of 2).
BLOCK_SIZE: tl.constexpr = BLOCK_PER_TOK * TOKS_PER_ROW
tl.static_assert(BLOCK_SIZE_PADDED <= 32768)
if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
n_tokens = tl.load(n_tokens)
nonzero_indx_size = n_tokens * TOKS_PER_ROW
pid_m = tl.program_id(0)
# Use BLOCK_SIZE_PADDED (a power of 2) for tl.arange, but stride by
# the actual BLOCK_SIZE so flat positions in NonzeroIndx are correct.
# Elements with offs_local >= BLOCK_SIZE have offs_global beyond the
# valid range, get col_indx = 0xffff, and are filtered by the mask
# below without producing any output.
offs_local = tl.arange(0, BLOCK_SIZE_PADDED)
offs_global = pid_m * BLOCK_SIZE + offs_local
mask = offs_global < nonzero_indx_size
col_indx = tl.load(NonzeroIndx + offs_global, mask=mask, other=-1).to(tl.uint32)
kv_pairs = ((col_indx << 16) | offs_local).to(tl.uint32)
kv_pairs = tl.sort(kv_pairs, 0)
col_indx = kv_pairs >> 16
offs_global = pid_m * BLOCK_SIZE + (kv_pairs & 0xFFFF)
mask = col_indx != 0xFFFF
x = kv_pairs & 0xFFFF0000 | 0x00000001
cols_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
exclusive_run_lengths = (cols_and_inclusive_run_lengths - 1) & 0xFFFF
row_sorted_indx = tl.load(
ColPartialSum + pid_m * stride_pm + col_indx * stride_pn, mask=mask
)
row_sorted_indx += tl.load(ColOffs + col_indx, mask=mask)
row_sorted_indx += exclusive_run_lengths
tl.store(RowSortedIndx + offs_global, row_sorted_indx, mask=mask)
tl.store(ColSortedIndx + row_sorted_indx, offs_global, mask=mask)
def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
assert nonzero_indx.ndim == 2
PARTIAL_BLOCK_M = 32
col_sum, col_partial_sum = sum_bitmatrix_rows(
bitmatrix, partials_block_size=PARTIAL_BLOCK_M
)
device = bitmatrix.device
n_indx = nonzero_indx.numel()
n_cols = bitmatrix.shape[1]
col_offs = torch.empty(n_cols, dtype=torch.int32, device=device)
combined_indx = torch.empty(n_indx * 2, dtype=torch.int32, device=device)
col_sorted_indx = combined_indx[:n_indx]
row_sorted_indx = combined_indx[n_indx:]
MEMSET_BLOCK = 1024
memset_grid = (cdiv(n_indx * 2, MEMSET_BLOCK) + n_cols + 1,)
_bm._bitmatrix_metadata_compute_stage1[memset_grid](
combined_indx,
n_indx * 2,
-1,
MEMSET_BLOCK,
col_sum,
col_offs,
col_sum.shape[0],
col_partial_sum,
col_partial_sum.shape[0],
col_partial_sum.stride(0),
col_partial_sum.stride(1),
BLOCK_M=512,
BLOCK_N=512,
)
toks_per_row = nonzero_indx.shape[-1]
block_size = PARTIAL_BLOCK_M * toks_per_row
# Next power of 2 >= block_size (required by tl.arange).
block_size_padded = 1 << (max(block_size, 1) - 1).bit_length()
compute_grid = (cdiv(bitmatrix.shape_max[0], PARTIAL_BLOCK_M),)
_stage2_pow2[compute_grid](
col_sorted_indx,
row_sorted_indx,
nonzero_indx,
bitmatrix.shape[0],
col_partial_sum,
col_partial_sum.stride(0),
col_partial_sum.stride(1),
col_offs,
TOKS_PER_ROW=toks_per_row,
BLOCK_PER_TOK=PARTIAL_BLOCK_M,
BLOCK_SIZE_PADDED=block_size_padded,
)
return BitmatrixMetadata(
col_sum=col_sum,
col_sorted_indx=col_sorted_indx,
row_sorted_indx=row_sorted_indx,
)
# The most reliable patch point: SparseMatrix.__post_init__ looks up
# make_bitmatrix_metadata via its own __globals__ dict (the tensor.py
# module dict). Patching through __globals__ works regardless of how
# sys.modules maps "triton_kernels.tensor" vs
# "vllm.third_party.triton_kernels.tensor".
from triton_kernels.tensor import SparseMatrix as _SparseMatrix
_SparseMatrix.__post_init__.__globals__["make_bitmatrix_metadata"] = (
_make_bitmatrix_metadata_pow2_safe
)
# Also patch the bitmatrix module itself in case it is imported directly.
_bm.make_bitmatrix_metadata = _make_bitmatrix_metadata_pow2_safe
use_legacy_triton_kernels = False
if has_triton_kernels():
......@@ -59,6 +213,8 @@ if has_triton_kernels():
use_legacy_triton_kernels = True
else:
raise
if not use_legacy_triton_kernels:
_patch_make_bitmatrix_metadata()
except (AttributeError, ImportError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
......@@ -497,6 +653,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
if not has_triton_kernels():
return False
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod
......@@ -698,6 +856,37 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
ops.moe_sum(input, output)
def activation(
self,
activation: MoEActivation,
output: torch.Tensor,
input: torch.Tensor,
) -> None:
quant_config = self.quant_config or FUSED_MOE_UNQUANTIZED_CONFIG
if activation == MoEActivation.SWIGLUOAI:
alpha = (
quant_config.gemm1_alpha
if quant_config.gemm1_alpha is not None
else 1.702
)
limit = (
quant_config.gemm1_clamp_limit
if quant_config.gemm1_clamp_limit is not None
else 7.0
)
torch.ops._C.swigluoai_and_mul(output, input, alpha, limit)
elif (
activation == MoEActivation.SILU
and quant_config.gemm1_clamp_limit is not None
):
swiglu_limit_func(
output,
input,
quant_config.gemm1_clamp_limit,
)
else:
super().activation(activation, output, input)
def apply(
self,
output: torch.Tensor,
......@@ -812,9 +1001,9 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
act_input,
)
# matmul_ogs grouped reduction fuse sum across multiple experts:
# matmul_ogs grouped reduction fuses sum across multiple experts:
# y[dst_indx // n_expts_act, :] += x
# Need to set n_expts_act to 1 to unfuse moe_sum
# Set n_expts_act to 1 to unfuse the sum so we can do it manually via moe_sum.
routing_data.n_expts_act = 1
matmul_ogs(
......@@ -878,6 +1067,8 @@ class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
if not has_triton_kernels():
return False
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod
......
......@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kMxfp4Static,
......@@ -32,10 +33,8 @@ class TrtLlmMxfp4ExpertsBase:
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
**kwargs,
):
# NOTE: FusedMoEExperts.__init__ is called by the concrete subclass
# (Monolithic/Modular) via MRO, not here, to avoid mypy issues with
# multiple inheritance. This matches the NvFP4 expert pattern.
self.moe_config = moe_config
self.quant_config = quant_config
......@@ -48,23 +47,34 @@ class TrtLlmMxfp4ExpertsBase:
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
# MXFP4-specific TRTLLM parameters
# MXFP4-specific TRTLLM parameters from quant_config
device = torch.accelerator.current_device_index()
self.gemm1_alpha = torch.tensor(
[1.702] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
self.gemm1_beta = torch.tensor(
[1.0] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
self.gemm1_clamp_limit = torch.tensor(
[7.0] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
if quant_config.gemm1_alpha is not None:
self.gemm1_alpha = torch.tensor(
[quant_config.gemm1_alpha] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
else:
self.gemm1_alpha = None
if quant_config.gemm1_beta is not None:
self.gemm1_beta = torch.tensor(
[quant_config.gemm1_beta] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
else:
self.gemm1_beta = None
if quant_config.gemm1_clamp_limit is not None:
self.gemm1_clamp_limit = torch.tensor(
[quant_config.gemm1_clamp_limit] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
else:
self.gemm1_clamp_limit = None
from vllm.config import get_current_vllm_config
......@@ -97,7 +107,7 @@ class TrtLlmMxfp4ExpertsBase:
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SWIGLUOAI
return activation in (MoEActivation.SWIGLUOAI, MoEActivation.SILU)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
......@@ -190,36 +200,41 @@ class TrtLlmMxfp4ExpertsMonolithic(
output = torch.empty_like(hidden_states)
return trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.w1_scale,
gemm1_bias=self.w1_bias,
gemm1_alpha=self.gemm1_alpha,
gemm1_beta=self.gemm1_beta,
gemm1_clamp_limit=self.gemm1_clamp_limit,
gemm2_weights=w2,
gemm2_weights_scale=self.w2_scale,
gemm2_bias=self.w2_bias,
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=global_num_experts,
top_k=self.topk,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=self.routing_method_type,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)[0]
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
with autotune(_is_fi_autotuning):
trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.w1_scale,
gemm1_bias=self.w1_bias,
gemm1_alpha=self.gemm1_alpha,
gemm1_beta=self.gemm1_beta,
gemm1_clamp_limit=self.gemm1_clamp_limit,
gemm2_weights=w2,
gemm2_weights_scale=self.w2_scale,
gemm2_bias=self.w2_bias,
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=global_num_experts,
top_k=self.topk,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=self.routing_method_type,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)
return output
class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular):
......@@ -239,6 +254,16 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
) -> bool:
return True
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# Modular kernel handles only the expert computation;
# routing is done externally, so accept any routing method.
return True
def supports_expert_map(self) -> bool:
return True
......@@ -282,7 +307,7 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
):
topk = topk_ids.size(-1)
local_num_experts = w1.size(0)
intermediate_size = w2.size(1)
intermediate_size = self.intermediate_size_per_partition
local_expert_offset = self.moe_config.ep_rank * local_num_experts
# Handle input quantization
......@@ -302,9 +327,8 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
x_quant = hidden_states
x_scale = None
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
# Pack topk ids and weights into format expected by the kernel.
packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
assert self.w1_scale is not None
assert self.w2_scale is not None
......@@ -333,7 +357,10 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
"local_expert_offset": local_expert_offset,
"local_num_experts": local_num_experts,
"routed_scaling_factor": None,
"routing_method_type": self.routing_method_type,
# Modular kernel receives pre-routed tokens, so routing
# is already done. Use Renormalize as a safe default that
# the TRTLLM C++ kernel supports.
"routing_method_type": RoutingMethodType.Renormalize,
"do_finalize": True,
"output": output,
"tune_max_num_tokens": max(self.max_capture_size, 1),
......@@ -341,12 +368,9 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
from flashinfer import trtllm_fp4_block_scale_routed_moe
from vllm.utils.flashinfer import autotune
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
with autotune(False):
# Enable autotune when,
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is
# resolved.
with autotune(_is_fi_autotuning):
trtllm_fp4_block_scale_routed_moe(**kwargs)
return output
......@@ -50,6 +50,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .utils import swiglu_limit_func
def _fused_marlin_moe(
hidden_states: torch.Tensor,
......@@ -88,6 +90,7 @@ def _fused_marlin_moe(
output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
is_k_full: bool = True,
clamp_limit: float | None = None,
) -> torch.Tensor:
assert hidden_states.ndim == 2
M, K = hidden_states.size()
......@@ -155,11 +158,18 @@ def _fused_marlin_moe(
use_fp32_reduce=True,
is_zp_float=False,
)
activation_func(
activation,
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
)
if clamp_limit is not None and activation == MoEActivation.SILU:
swiglu_limit_func(
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
clamp_limit,
)
else:
activation_func(
activation,
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
)
if output is None:
output = intermediate_cache3
......@@ -247,6 +257,7 @@ def fused_marlin_moe(
output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
inplace: bool = False,
clamp_limit: float | None = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -363,6 +374,7 @@ def fused_marlin_moe(
output=None,
input_dtype=input_dtype,
is_k_full=is_k_full,
clamp_limit=clamp_limit,
).view(-1, topk, K)
if output is None:
......@@ -557,6 +569,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full
self.input_dtype = get_marlin_input_dtype()
self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
super().__init__(
moe_config=moe_config,
......@@ -850,6 +863,7 @@ class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase):
sort_indices2=self.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
clamp_limit=self.gemm1_clamp_limit,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
......
......@@ -169,5 +169,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
......@@ -268,6 +268,7 @@ class FusedMoE(PluggableLayer):
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
swiglu_limit: float | None = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -285,6 +286,7 @@ class FusedMoE(PluggableLayer):
routed_output_transform: torch.nn.Module | None = None,
apply_routed_scale_to_output: bool = False,
zero_expert_type: str | None = None,
hash_indices_table: torch.Tensor | None = None,
):
super().__init__()
......@@ -294,6 +296,7 @@ class FusedMoE(PluggableLayer):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
self.swiglu_limit = swiglu_limit
# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
......@@ -455,6 +458,7 @@ class FusedMoE(PluggableLayer):
self.e_score_correction_bias = e_score_correction_bias
# TODO(bnell): end attributes
self.hash_indices_table = hash_indices_table
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = MoEActivation.from_str(activation)
......@@ -479,6 +483,7 @@ class FusedMoE(PluggableLayer):
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
zero_expert_type=zero_expert_type,
num_logical_experts=self.logical_num_experts,
hash_indices_table=self.hash_indices_table,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
......@@ -1541,10 +1546,12 @@ class FusedMoE(PluggableLayer):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
return self.runner.forward(
hidden_states,
router_logits,
input_ids,
)
@property
......
......@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
FusedMoEQuantDesc,
mxfp4_mxfp8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
......@@ -24,6 +25,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kMxfp4Static,
kMxfp8Dynamic,
)
......@@ -46,6 +48,8 @@ if has_triton_kernels():
class Mxfp4MoeBackend(Enum):
NONE = "None"
# DeepGEMM FP8xFP4 backend (SM100+)
DEEPGEMM_MXFP4 = "DEEPGEMM_MXFP4"
# FlashInfer TRTLLM backends
FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8"
FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16"
......@@ -81,7 +85,14 @@ TRITON_BACKENDS = (
def backend_to_kernel_cls(
backend: Mxfp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
if backend in (
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import (
DeepGemmFP4Experts,
)
return [DeepGemmFP4Experts]
elif backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
):
......@@ -159,11 +170,13 @@ def backend_to_kernel_cls(
def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
"""Map user's moe_backend string to Mxfp4MoeBackend."""
mapping: dict[str, Mxfp4MoeBackend] = {
"deep_gemm": Mxfp4MoeBackend.DEEPGEMM_MXFP4,
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
"flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
"flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
"flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
"triton": Mxfp4MoeBackend.TRITON,
"triton_unfused": Mxfp4MoeBackend.TRITON_UNFUSED,
"marlin": Mxfp4MoeBackend.MARLIN,
"aiter": Mxfp4MoeBackend.AITER,
"xpu": Mxfp4MoeBackend.XPU,
......@@ -177,7 +190,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
)
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
"""
Get available backends in priority order based on platform and config.
Only includes BF16 backends. MXFP8 backends are selected via env vars.
......@@ -187,7 +200,9 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend.AITER,
Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.TRITON_UNFUSED,
# TRITON_UNFUSED has bug with MTP support
# TODO re-enable after kernel is fixed
# TRITON_UNFUSED
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
Mxfp4MoeBackend.XPU,
......@@ -196,8 +211,28 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
return _AVAILABLE_BACKENDS
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
"""
Get available backends in priority order. SM100+ prefers DeepGEMM FP4 /
TRTLLM MXFP8; SM90 falls through to Triton_unfused or Marlin (the
backend-level ``is_supported_config`` check filters by device capability).
"""
_AVAILABLE_BACKENDS = [
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.DEEPGEMM_MXFP4,
# TRITON_UNFUSED has bug with MTP support
# TODO re-enable after kernel is fixed
# TRITON_UNFUSED
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
]
return _AVAILABLE_BACKENDS
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
"""Map backend to its activation key (MXFP8 or None for BF16)."""
"""Map backend to its activation key (FP8, MXFP8, or None for BF16)."""
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
return kFp8Dynamic128Sym
if backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
......@@ -290,7 +325,7 @@ def select_gpt_oss_mxfp4_moe_backend(
)
# Select kernels in order of backend.
AVAILABLE_BACKENDS = _get_priority_backends()
AVAILABLE_BACKENDS = _get_priority_backends_for_gpt_oss()
# Handle explicit FlashInfer MXFP4 BF16 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
......@@ -387,11 +422,95 @@ def select_gpt_oss_mxfp4_moe_backend(
return Mxfp4MoeBackend.NONE, None
def select_mxfp4_moe_backend(
config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
"""
Select the MXFP4 MoE backend with MXFP8 activation as top priority.
Falls back through BF16 and other backends.
"""
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
def _make_log_backend(backend: Mxfp4MoeBackend):
return f"Using '{backend.value}' Mxfp4 MoE backend."
def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
if reason:
return (
f"Mxfp4 MoE backend '{backend.value}' does not support the "
f"deployment configuration since {reason}."
)
return (
f"Mxfp4 MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
def _return_or_raise(
backend: Mxfp4MoeBackend,
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
reason: str | None = None
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Honor explicit moe_backend (e.g. "marlin", "triton_unfused") before
# falling back to the auto priority list.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_mxfp4_backend(runner_backend)
if (
activation_format == mk.FusedMoEActivationFormat.BatchedExperts
and requested_backend == Mxfp4MoeBackend.MARLIN
):
requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN
return _return_or_raise(
requested_backend,
config,
kMxfp4Static,
_backend_activation_key(requested_backend),
activation_format,
)
# Iterate priority backends: TRTLLM MXFP8, then Triton.
for backend in _get_priority_backends():
activation_key = _backend_activation_key(backend)
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, kMxfp4Static, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No MXFP4 MoE backend supports the deployment configuration."
)
def mxfp4_round_up_hidden_size_and_intermediate_size(
backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
) -> tuple[int, int]:
"""Round up hidden_size and intermediate_size based on backend requirements."""
if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
# DeepGEMM requires M/N/K alignment
intermediate_size = round_up(intermediate_size, 128)
hidden_size = round_up(hidden_size, 128)
elif backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
intermediate_size = round_up(intermediate_size, 128)
if current_platform.is_xpu():
hidden_size = round_up(hidden_size, 128)
......@@ -434,6 +553,20 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
]:
"""Convert loaded weights into backend-specific kernel format."""
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_upcast_e8m0_to_fp32,
)
return (
w13_weight.data,
w2_weight.data,
_upcast_e8m0_to_fp32(w13_weight_scale.data),
_upcast_e8m0_to_fp32(w2_weight_scale.data),
w13_bias,
w2_bias,
)
num_experts = w13_weight.shape[0]
intermediate_size = w13_weight.shape[1] // 2
hidden_size = w13_weight.shape[2] * 2
......@@ -738,9 +871,10 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
elif mxfp4_backend in TRITON_BACKENDS:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
assert w13_bias is not None and w2_bias is not None
w13_bias = w13_bias.to(torch.float32)
w2_bias = w2_bias.to(torch.float32)
if w13_bias is not None:
w13_bias = w13_bias.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.to(torch.float32)
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
w13_weight,
......@@ -797,15 +931,271 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
)
def convert_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend: Mxfp4MoeBackend,
layer: torch.nn.Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w13_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
_cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor,
Union[torch.Tensor, "PrecisionConfig"],
Union[torch.Tensor, "PrecisionConfig"],
torch.Tensor | None,
torch.Tensor | None,
]:
"""Convert loaded weights into backend-specific kernel format.
Supports DeepGEMM, TRTLLM MXFP8, Triton and Marlin backends.
"""
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_upcast_e8m0_to_fp32,
)
# Weights stay as uint8 packed FP4 — no layout change needed.
# Convert E8M0 uint8 scales to float32.
return (
w13_weight.data,
w2_weight.data,
_upcast_e8m0_to_fp32(w13_weight_scale.data),
_upcast_e8m0_to_fp32(w2_weight_scale.data),
w13_bias,
w2_bias,
)
if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_mxfp4_layer_for_marlin,
)
return prepare_moe_mxfp4_layer_for_marlin(
layer,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
num_experts = w13_weight.shape[0]
intermediate_size = w13_weight.shape[1] // 2
hidden_size = w13_weight.shape[2] * 2
sf_block_size = 32 # mxfp4 block size
if mxfp4_backend in TRTLLM_BACKENDS:
assert _cache_permute_indices is not None
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
w13_weight = w13_weight.data
w2_weight = w2_weight.data
w13_weight_scale = w13_weight_scale.data
w2_weight_scale = w2_weight_scale.data
if w13_bias is not None:
w13_bias = w13_bias.data.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.data.to(torch.float32)
# Swap w1/w3 and interleave to match TRTLLM SwiGLU convention.
# Standard loading gives contiguous [w1/gate, w3/up].
# TRTLLM kernel expects interleaved [w3_0, w1_0, w3_1, w1_1, ...].
w1_weight = w13_weight[:, :intermediate_size, :]
w3_weight = w13_weight[:, intermediate_size:, :]
w13_weight = torch.stack([w3_weight, w1_weight], dim=2).reshape(
w13_weight.shape
)
w1_scale = w13_weight_scale[:, :intermediate_size, :]
w3_scale = w13_weight_scale[:, intermediate_size:, :]
w13_weight_scale = torch.stack([w3_scale, w1_scale], dim=2).reshape(
w13_weight_scale.shape
)
if w13_bias is not None:
b1 = w13_bias[:, :intermediate_size]
b3 = w13_bias[:, intermediate_size:]
w13_bias = torch.stack([b3, b1], dim=2).reshape(w13_bias.shape)
# Shuffle weights and scaling factors for transposed mma output.
# Permute indices depend only on shape (cached by torch.Size),
# so compute once and apply to all experts via batched indexing.
epilogue_tile_m = 128
# w13 weight permute
w13_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_weight[0].view(torch.uint8),
epilogue_tile_m,
).to(w13_weight.device)
w13_weight = w13_weight.view(torch.uint8)[:, w13_perm].contiguous()
# w13 scale permute + interleave
w13_sf_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_weight_scale[0].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
).to(w13_weight_scale.device)
w13_s = w13_weight_scale.view(torch.uint8)[:, w13_sf_perm].contiguous()
E, N_s, K_s = w13_s.shape
w13_weight_scale = (
nvfp4_block_scale_interleave(w13_s.reshape(E * N_s, K_s))
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
# w2 weight permute
w2_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_weight[0].view(torch.uint8),
epilogue_tile_m,
).to(w2_weight.device)
w2_weight = w2_weight.view(torch.uint8)[:, w2_perm].contiguous()
# w2 scale permute + interleave
w2_sf_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_weight_scale[0].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
).to(w2_weight_scale.device)
w2_s = w2_weight_scale.view(torch.uint8)[:, w2_sf_perm].contiguous()
E2, N2_s, K2_s = w2_s.shape
w2_weight_scale = (
nvfp4_block_scale_interleave(w2_s.reshape(E2 * N2_s, K2_s))
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
# w13 bias permute
if w13_bias is not None:
w13_b_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_bias[0].reshape(-1, 1),
epilogue_tile_m,
).to(w13_bias.device)
w13_bias = w13_bias.reshape(num_experts, -1, 1)[:, w13_b_perm].reshape(
num_experts, -1
)
# w2 bias permute
if w2_bias is not None:
w2_b_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_bias[0].reshape(-1, 1),
epilogue_tile_m,
).to(w2_bias.device)
w2_bias = w2_bias.reshape(num_experts, -1, 1)[:, w2_b_perm].reshape(
num_experts, -1
)
return (
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
elif mxfp4_backend in TRITON_BACKENDS:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
if mxfp4_backend == Mxfp4MoeBackend.TRITON:
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
shape = w.shape
n = shape[-1]
first = w[..., : n // 2]
second = w[..., n // 2 :]
stacked = torch.stack((first, second), dim=-1)
return stacked.reshape(shape)
w13_weight = shuffle_weight(w13_weight)
w13_weight_scale = shuffle_weight(w13_weight_scale)
if w13_bias is not None:
w13_bias = shuffle_weight(w13_bias.to(torch.float32))
else:
if w13_bias is not None:
w13_bias = w13_bias.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.to(torch.float32)
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
w13_weight,
w13_weight_scale,
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
w2_weight,
w2_weight_scale,
)
w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
)
w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
del layer.w13_weight
del layer.w2_weight
return (
w13_weight,
w2_weight,
w13_precision_config,
w2_precision_config,
w13_bias,
w2_bias,
)
else:
raise ValueError(
f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. "
f"Expected TRTLLM or Triton backend."
)
def make_mxfp4_moe_quant_config(
mxfp4_backend: Mxfp4MoeBackend,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
swiglu_limit: float | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig | None:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if mxfp4_backend in (
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
# DeepGEMM FP4 uses FP8 per-token-group activation quantization
# with block 128, matching the FP8 DeepGEMM path.
_fp8_dtype = current_platform.fp8_dtype()
_block_shape = GroupShape(128, 128)
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
):
......@@ -814,6 +1204,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.MARLIN,
......@@ -829,6 +1222,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
else:
return ocp_mx_moe_quant_config(
......@@ -837,6 +1233,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
......
......@@ -228,6 +228,8 @@ class BaseRouter(FusedMoERouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute the actual routing logic.
......@@ -249,6 +251,8 @@ class BaseRouter(FusedMoERouter):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
......@@ -278,7 +282,7 @@ class BaseRouter(FusedMoERouter):
# Step 3: Compute routing (delegated to subclass)
topk_weights, topk_ids = self._compute_routing(
hidden_states, router_logits, indices_type
hidden_states, router_logits, indices_type, input_ids=input_ids
)
# Capture logical ids before EPLB mapping.
......
......@@ -46,6 +46,8 @@ class CustomRoutingRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using the custom routing function."""
topk_weights, topk_ids = self.custom_routing_function(
......
......@@ -31,6 +31,8 @@ class FusedMoERouter(ABC):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
......
......@@ -4,6 +4,7 @@ import functools
from collections.abc import Callable
import torch
import torch.nn.functional as F
import vllm._custom_ops as ops
import vllm.envs as envs
......@@ -56,6 +57,32 @@ def vllm_topk_sigmoid(
return topk_weights, topk_indices
def vllm_topk_softplus_sqrt(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
input_tokens: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, ...]:
ops.topk_hash_softplus_sqrt(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
input_tokens,
hash_indices_table,
)
return topk_weights, topk_indices
@functools.lru_cache(maxsize=8)
def _aiter_get_num_expert_group(num_experts: int) -> int:
_AITER_MAX_EXPERTS_PER_GROUP = 32
......@@ -72,11 +99,14 @@ def _aiter_get_num_expert_group(num_experts: int) -> int:
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
scoring_func: str,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
scoring_func: str = "softmax",
indices_type: torch.dtype | None = None,
input_tokens: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
):
if not rocm_aiter_ops.is_fused_moe_enabled():
assert hidden_states.size(0) == gating_output.size(0), (
......@@ -107,6 +137,8 @@ def fused_topk_bias(
renormalize,
e_score_correction_bias,
)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids
elif scoring_func == "sigmoid":
topk_weights, topk_ids = vllm_topk_sigmoid(
......@@ -117,9 +149,24 @@ def fused_topk_bias(
renormalize,
e_score_correction_bias,
)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids
elif scoring_func == "sqrtsoftplus":
return vllm_topk_softplus_sqrt(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
input_tokens,
hash_indices_table,
routed_scaling_factor,
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid":
M = hidden_states.size(0)
num_experts = gating_output.shape[-1]
......@@ -143,6 +190,8 @@ def fused_topk_bias(
topk_group=num_expert_group,
need_renorm=renormalize,
)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids
n_routed_experts = gating_output.shape[-1]
......@@ -150,20 +199,31 @@ def fused_topk_bias(
scores = gating_output.softmax(dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
elif scoring_func == "sqrtsoftplus":
scores = F.softplus(gating_output).sqrt()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
if e_score_correction_bias is not None:
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
else:
scores_for_choice = scores.view(-1, n_routed_experts)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = envs.VLLM_BATCH_INVARIANT
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
if hash_indices_table is not None:
topk_indices = hash_indices_table[input_tokens]
else:
use_sorted = envs.VLLM_BATCH_INVARIANT
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[
1
]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(
topk_weights = topk_weights.to(torch.float32)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_indices.to(
torch.int32 if indices_type is None else indices_type
)
......@@ -176,12 +236,14 @@ class FusedTopKBiasRouter(BaseRouter):
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor,
scoring_func: str,
e_score_correction_bias: torch.Tensor | None = None,
renormalize: bool = True,
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
*,
scoring_func: str = "sigmoid",
hash_indices_table: torch.Tensor | None = None,
):
super().__init__(
top_k=top_k,
......@@ -194,6 +256,8 @@ class FusedTopKBiasRouter(BaseRouter):
self.renormalize = renormalize
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.scoring_func = scoring_func
self._hash_indices_table = hash_indices_table
@property
def routing_method_type(self) -> RoutingMethodType:
......@@ -210,19 +274,23 @@ class FusedTopKBiasRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using fused top-k with bias."""
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias.data
if self.e_score_correction_bias is not None
else None,
topk=self.top_k,
renormalize=self.renormalize,
scoring_func=self.scoring_func,
indices_type=indices_type,
input_tokens=input_ids,
hash_indices_table=self._hash_indices_table,
routed_scaling_factor=self.routed_scaling_factor,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
return topk_weights, topk_ids
......@@ -151,6 +151,8 @@ class FusedTopKRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using standard fused top-k."""
topk_weights, topk_ids, token_expert_indices = fused_topk(
......
......@@ -292,6 +292,8 @@ class GroupedTopKRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using grouped top-k."""
......@@ -308,6 +310,7 @@ class GroupedTopKRouter(BaseRouter):
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
......
......@@ -55,6 +55,7 @@ def create_fused_moe_router(
# zero expert parameters
zero_expert_type: str | None = None,
num_logical_experts: int | None = None,
hash_indices_table: torch.Tensor | None = None,
) -> FusedMoERouter:
"""
Factory function to create the appropriate FusedMoERouter subclass based on
......@@ -99,6 +100,9 @@ def create_fused_moe_router(
num_logical_experts: Number of real (non-zero) experts. Required when
zero_expert_type is not None.
Hash Indices Table:
Used to map input_ids to experts, need for Deepseek V4
Returns:
An instance of the appropriate FusedMoERouter subclass
"""
......@@ -179,17 +183,20 @@ def create_fused_moe_router(
indices_type_getter=indices_type_getter,
)
if e_score_correction_bias is not None:
assert scoring_func in ["sigmoid", "softmax", "sqrtsoftplus"]
if e_score_correction_bias is not None or hash_indices_table is not None:
return FusedTopKBiasRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias,
scoring_func=scoring_func,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
scoring_func=scoring_func,
hash_indices_table=hash_indices_table,
)
return FusedTopKRouter(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment