Unverified Commit 4d51588e authored by Yifan Qiao's avatar Yifan Qiao Committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: qizixi <zixi@inferact.ai>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
parent 32e45636
......@@ -125,12 +125,16 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
)
_cublaslt_gemm_nt_impl: Callable[..., Any] | None = None
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_fp8_einsum_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_grouped_fp4_impl: Callable[..., Any] | None = None
_fp8_fp4_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_fp4_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_tf32_hc_prenorm_gemm_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
......@@ -173,20 +177,27 @@ def _import_deep_gemm():
def _lazy_init() -> None:
"""Import deep_gemm and resolve symbols on first use."""
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
global _cublaslt_gemm_nt_impl
global _fp8_gemm_nt_impl, _fp8_einsum_impl
global _grouped_impl, _grouped_masked_impl, _grouped_fp4_impl
global _fp8_fp4_mqa_logits_impl, _fp8_fp4_paged_mqa_logits_impl
global _get_paged_mqa_logits_metadata_impl
global _tf32_hc_prenorm_gemm_impl
global _get_mn_major_tma_aligned_tensor_impl
global _get_mk_alignment_for_contiguous_layout_impl
global _transform_sf_into_required_layout_impl
# fast path
if (
_fp8_gemm_nt_impl is not None
_cublaslt_gemm_nt_impl is not None
or _fp8_gemm_nt_impl is not None
or _fp8_einsum_impl is not None
or _grouped_impl is not None
or _grouped_masked_impl is not None
or _fp8_mqa_logits_impl is not None
or _fp8_paged_mqa_logits_impl is not None
or _grouped_fp4_impl is not None
or _fp8_fp4_mqa_logits_impl is not None
or _fp8_fp4_paged_mqa_logits_impl is not None
or _get_paged_mqa_logits_metadata_impl is not None
or _tf32_hc_prenorm_gemm_impl is not None
or _get_mk_alignment_for_contiguous_layout_impl is not None
or _transform_sf_into_required_layout_impl is not None
):
......@@ -206,14 +217,20 @@ def _lazy_init() -> None:
if _dg is None:
return
_cublaslt_gemm_nt_impl = getattr(_dg, "cublaslt_gemm_nt", None)
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
_fp8_einsum_impl = getattr(_dg, "fp8_einsum", None)
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
_grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
_fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
_grouped_fp4_impl = getattr(_dg, "m_grouped_fp8_fp4_gemm_nt_contiguous", None)
# DeepGEMM exposes fp8_fp4_*_mqa_logits as the canonical symbols that
# handle both the FP8 and FP4 Q/K paths via a tuple-typed `q`.
_fp8_fp4_mqa_logits_impl = getattr(_dg, "fp8_fp4_mqa_logits", None)
_fp8_fp4_paged_mqa_logits_impl = getattr(_dg, "fp8_fp4_paged_mqa_logits", None)
_get_paged_mqa_logits_metadata_impl = getattr(
_dg, "get_paged_mqa_logits_metadata", None
)
_tf32_hc_prenorm_gemm_impl = getattr(_dg, "tf32_hc_prenorm_gemm", None)
_get_mn_major_tma_aligned_tensor_impl = getattr(
_dg, "get_mn_major_tma_aligned_tensor", None
)
......@@ -259,6 +276,13 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
return _get_mn_major_tma_aligned_tensor_impl(x)
def cublaslt_gemm_nt(*args, **kwargs):
    """Forward a call to DeepGEMM's ``cublaslt_gemm_nt`` symbol.

    Symbols are resolved lazily on first use. All positional and keyword
    arguments are passed through untouched. If the symbol could not be
    resolved, the call is handed to ``_missing`` instead.
    """
    _lazy_init()
    impl = _cublaslt_gemm_nt_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)
def fp8_gemm_nt(*args, **kwargs):
_lazy_init()
if _fp8_gemm_nt_impl is None:
......@@ -271,6 +295,13 @@ def fp8_gemm_nt(*args, **kwargs):
return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
def fp8_einsum(*args, **kwargs):
    """Forward a call to DeepGEMM's ``fp8_einsum`` symbol.

    Symbols are resolved lazily on first use. Arguments are passed through
    unchanged; when the symbol is unavailable the call goes to ``_missing``.
    """
    _lazy_init()
    impl = _fp8_einsum_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
_lazy_init()
if _grouped_impl is None:
......@@ -280,6 +311,15 @@ def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
)
def m_grouped_fp8_fp4_gemm_nt_contiguous(*args, **kwargs):
    """Forward to DeepGEMM's ``m_grouped_fp8_fp4_gemm_nt_contiguous``.

    Besides the caller's arguments, ``disable_ue8m0_cast`` is always supplied
    as the negation of ``is_deep_gemm_e8m0_used()``. When the symbol is
    unavailable the call goes to ``_missing``.
    """
    _lazy_init()
    impl = _grouped_fp4_impl
    if impl is None:
        return _missing(*args, **kwargs)
    # Only query the UE8M0 setting once we know there is a kernel to call.
    skip_ue8m0_cast = not is_deep_gemm_e8m0_used()
    return impl(*args, disable_ue8m0_cast=skip_ue8m0_cast, **kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
_lazy_init()
if _grouped_masked_impl is None:
......@@ -298,37 +338,48 @@ def transform_sf_into_required_layout(*args, **kwargs):
)
def fp8_mqa_logits(
q: torch.Tensor,
def fp8_fp4_mqa_logits(
q: tuple[torch.Tensor, torch.Tensor | None],
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
clean_logits: bool,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
"""Compute MQA logits for a single sequence without KV paging.
Unified FP8/FP4 dispatch — the underlying DeepGEMM kernel takes
``q = (values, scales_or_None)`` where ``scales`` is None for FP8 Q
(per-token scale is folded into ``weights``) and a packed block-scale
tensor for MXFP4 Q.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N])
with dtype `torch.float32`.
q: Tuple ``(q_values, q_scale)``. FP8 path: q_values is [M, H, D]
float8_e4m3fn and q_scale is None (per-token scale is folded
into ``weights``). FP4 path: q_values is packed uint8 and
q_scale is the companion block-scale tensor.
kv: Tuple `(k_packed, k_scales)` — FP8 layout is [N, D]
float8_e4m3fn plus fp32 scales [N]; FP4 layout is packed uint8.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ks: Start indices (inclusive) for valid K per query
position, shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query
position, shape [M], dtype int32.
clean_logits: Whether to clean the unfilled logits into `-inf`.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
_lazy_init()
if _fp8_mqa_logits_impl is None:
if _fp8_fp4_mqa_logits_impl is None:
return _missing()
return _fp8_mqa_logits_impl(
q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=clean_logits
return _fp8_fp4_mqa_logits_impl(
q,
kv,
weights,
cu_seqlen_ks,
cu_seqlen_ke,
clean_logits=clean_logits,
)
......@@ -344,7 +395,7 @@ def get_paged_mqa_logits_metadata(
num_sms: Number of SMs available. 132 for Hopper
Returns:
Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
Backend-specific tensor consumed by `fp8_fp4_paged_mqa_logits` to
schedule work across SMs.
"""
_lazy_init()
......@@ -353,9 +404,9 @@ def get_paged_mqa_logits_metadata(
return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms)
def fp8_paged_mqa_logits(
q_fp8: torch.Tensor,
kv_cache_fp8: torch.Tensor,
def fp8_fp4_paged_mqa_logits(
q: tuple[torch.Tensor, torch.Tensor | None],
kv_cache: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
......@@ -363,14 +414,20 @@ def fp8_paged_mqa_logits(
max_model_len: int,
clean_logits: bool,
) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache.
"""Compute MQA logits using a paged KV-cache.
Unified FP8/FP4 dispatch — the underlying DeepGEMM kernel takes
``q = (values, scales_or_None)``; pass ``(q_tensor, None)`` for the FP8
path and ``(q_values, q_scale)`` for MXFP4.
Args:
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
4 bytes per (block,pos) store the `float` dequant scale.
q: Tuple ``(q_values, q_scale)``. FP8 path: q_values is
[B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path:
q_values is packed uint8 and q_scale is the companion
block-scale tensor.
kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1,
D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos)
storing the float dequant scale.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
......@@ -386,11 +443,11 @@ def fp8_paged_mqa_logits(
`torch.float32`.
"""
_lazy_init()
if _fp8_paged_mqa_logits_impl is None:
if _fp8_fp4_paged_mqa_logits_impl is None:
return _missing()
return _fp8_paged_mqa_logits_impl(
q_fp8,
kv_cache_fp8,
return _fp8_fp4_paged_mqa_logits_impl(
q,
kv_cache,
weights,
context_lens,
block_tables,
......@@ -400,6 +457,32 @@ def fp8_paged_mqa_logits(
)
def tf32_hc_prenorm_gemm(
    x: torch.Tensor,
    fn: torch.Tensor,
    out: torch.Tensor,
    sqrsum: torch.Tensor,
    num_split: int,
) -> torch.Tensor:
    """Dispatch to DeepGEMM's ``tf32_hc_prenorm_gemm`` kernel.

    The kernel performs the following computation:
        out = x.float() @ fn.T
        sqrsum = x.float().square().sum(-1)

    Shape requirements are documented at the call site, not here. All five
    arguments are forwarded positionally and unchanged. When the symbol is
    unavailable the call goes to ``_missing``.
    """
    _lazy_init()
    kernel = _tf32_hc_prenorm_gemm_impl
    if kernel is not None:
        return kernel(x, fn, out, sqrsum, num_split)
    return _missing()
def _ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
......@@ -482,10 +565,12 @@ __all__ = [
"calc_diff",
"DeepGemmQuantScaleFMT",
"fp8_gemm_nt",
"fp8_einsum",
"m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_fp8_fp4_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"fp8_mqa_logits",
"fp8_paged_mqa_logits",
"fp8_fp4_mqa_logits",
"fp8_fp4_paged_mqa_logits",
"get_paged_mqa_logits_metadata",
"per_block_cast_to_fp8",
"is_deep_gemm_e8m0_used",
......
......@@ -2,11 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum
from typing import Any
import torch
class AuxStreamType(Enum):
    """Kinds of auxiliary stream; ``Attention`` is the only member."""

    Attention = 1
class EventType(Enum):
    """Kinds of event: ``Main`` (0) and ``Attention`` (1)."""

    Main = 0
    Attention = 1
def maybe_execute_in_parallel(
fn0: Callable[[], Any],
fn1: Callable[[], Any],
......
......@@ -392,6 +392,11 @@ class CommonAttentionMetadata:
dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""
positions: torch.Tensor | None = None
"""(num_actual_tokens,) token positions. Optional; set when the caller
has positions available so that builders can pre-compute position-dependent
metadata (e.g. C128A topk indices for DeepSeek V4)."""
is_prefilling: torch.Tensor | None = None
"""(batch_size,) bool tensor: True if request is still in prefill phase
(num_computed_tokens < num_prompt_tokens). Used by some backends to
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _compressed_slot_mapping_kernel(
    # [num_tokens]
    slot_mapping_ptr,
    # [num_reqs + 1]
    query_start_loc_ptr,
    # [num_reqs]
    seq_lens_ptr,
    # [num_reqs, max_num_blocks]
    block_table_ptr,
    block_table_stride,
    block_size,
    COMPRESS_RATIO: tl.constexpr,
    PAD_ID: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    # One program per request: for every query token of that request, write
    # either the slot id of its compressed entry or PAD_ID.
    batch_idx = tl.program_id(0)
    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start
    seq_len = tl.load(seq_lens_ptr + batch_idx)
    # Position of the request's first query token within its sequence.
    start_pos = seq_len - query_len

    for i in range(0, query_len, TRITON_BLOCK_SIZE):
        offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
        mask = offset < query_len
        pos = start_pos + offset

        # Only the last token of each COMPRESS_RATIO-sized group maps to a
        # compressed slot. The `pos >= 0` guard rejects negative positions
        # (seq_len < query_len): with C-style remainder, e.g. pos == -1 would
        # otherwise satisfy the modulo test and be given a real slot id.
        # NOTE(review): assumes negative positions only occur for padded or
        # invalid requests — confirm against the callers.
        is_valid = (pos >= 0) & ((pos + 1) % COMPRESS_RATIO == 0)

        pos_after_compress = pos // COMPRESS_RATIO
        block_ids = pos_after_compress // block_size
        block_numbers = tl.load(
            block_table_ptr + batch_idx * block_table_stride + block_ids,
            mask=mask & is_valid,
        )
        slot_ids = block_numbers * block_size + pos_after_compress % block_size
        # Lanes that were not loaded hold unspecified values; overwrite them
        # with PAD_ID before storing.
        slot_ids = tl.where(is_valid, slot_ids, PAD_ID)
        tl.store(slot_mapping_ptr + query_start + offset, slot_ids, mask=mask)
def get_compressed_slot_mapping(
    num_tokens: int,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    compress_ratio: int,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the slot mapping for the compressed KV cache.

    Launches ``_compressed_slot_mapping_kernel`` with one program per row of
    ``block_table``. Entries the kernel does not write, or marks as invalid,
    hold ``-1``.

    Args:
        num_tokens: Number of leading entries of the result that are returned.
        query_start_loc: Per-request query start offsets, ``[num_reqs + 1]``.
        seq_lens: Per-request sequence lengths, ``[num_reqs]``.
        block_table: Block table, ``[num_reqs, max_num_blocks]``.
        block_size: Block size forwarded to the kernel.
        compress_ratio: Compression ratio forwarded to the kernel.
        out: Optional pre-allocated output buffer. When given, it is reset to
            ``-1`` in full and its first ``num_tokens`` entries are returned
            as a view.

    Returns:
        The slot mapping: a view into ``out`` when provided, otherwise a new
        int64 tensor on ``query_start_loc``'s device.
    """
    if out is None:
        slot_mapping = torch.full(
            (num_tokens,), -1, dtype=torch.int64, device=query_start_loc.device
        )
    else:
        # NOTE: Reset the whole buffer, not just the first `num_tokens`
        # entries, so anything the kernel leaves untouched reads as padding.
        out.fill_(-1)
        slot_mapping = out[:num_tokens]

    grid = (block_table.shape[0],)
    _compressed_slot_mapping_kernel[grid](
        slot_mapping,
        query_start_loc,
        seq_lens,
        block_table,
        block_table.stride(0),
        block_size,
        compress_ratio,
        PAD_ID=-1,
        TRITON_BLOCK_SIZE=1024,
    )
    return slot_mapping
......@@ -15,6 +15,8 @@ from vllm.model_executor.layers.attention.mla_attention import (
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
......@@ -27,6 +29,7 @@ from vllm.v1.attention.backend import (
MultipleOf,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping
from vllm.v1.attention.backends.mla.sparse_utils import (
triton_convert_req_index_to_global_index,
)
......@@ -65,8 +68,8 @@ MIN_HEADS_FOR_BF16_PREFILL = 32
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
For DeepSeek V3.2, in the "FP8 with scale" format, each token's KV cache is 656
Bytes, structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
`float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
......@@ -74,6 +77,16 @@ structured as:
the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
For DeepSeek V4, in the "FP8 with scale" format, each token's KV cache is 584
Bytes, structured as:
- **First 448 bytes:** The "quantized NoPE" part, containing 448
`float8_e4m3` values.
- **Next 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
- **Last 8 bytes:** Scale factors, containing 7 `ue8m0` values + 1B pad.
The first `ue8m0` is the scale for the first 64 `float8_e4m3` values,
the second for the next 64, and so on.
"""
......@@ -104,7 +117,8 @@ class FlashMLASparseBackend(AttentionBackend):
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
# V3.2: 576 (512 NoPE + 64 RoPE); DeepseekV4: 512 (448 NoPE + 64 RoPE)
return [512, 576]
@classmethod
def is_mla(cls) -> bool:
......@@ -127,13 +141,37 @@ class FlashMLASparseBackend(AttentionBackend):
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if cache_dtype_str == "fp8_ds_mla":
# custom storage format is 656 bytes
# see FlashMLA readme.md for details
# V3.2 main MLA: 656-byte custom storage format. See module docstring.
return (num_blocks, block_size, 656)
else:
return (num_blocks, block_size, head_size)
class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
    """FlashMLA sparse backend variant for DeepSeek V4.

    Overrides the kernel block size, the backend name and the KV-cache shape
    of ``FlashMLASparseBackend``.
    """

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the supported kernel block sizes (only 256)."""
        return [256]

    @staticmethod
    def get_name() -> str:
        """Return the backend name."""
        return "V4_FLASHMLA_SPARSE"

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, last_dim)``.

        ``last_dim`` is 584 for the ``"fp8_ds_mla"`` cache dtype and
        ``head_size`` otherwise. ``num_kv_heads`` is not used.
        """
        if cache_dtype_str != "fp8_ds_mla":
            return (num_blocks, block_size, head_size)
        # DeepseekV4 main MLA: 584B per token
        # (448 NoPE + 128 RoPE + 8 scale bytes).
        # head_size passed in is the semantic head_dim (512).
        return (num_blocks, block_size, 584)
@dataclass
class FlashMLASparseMetadata(AttentionMetadata):
num_reqs: int
......@@ -159,6 +197,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
class FP8SeparatePrefillDecode:
@dataclass
class Decode:
seq_lens: torch.Tensor
kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
decode_query_len: int # needed for reshape in spec decode
......@@ -206,6 +245,13 @@ class FlashMLASparseMetadata(AttentionMetadata):
fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
fp8_use_mixed_batch: bool = False
# Pre-computed C128A metadata (DeepseekV4 only, compress_ratio == 128).
# Decode: global slot ids + valid-entry counts (fused from positions).
c128a_global_decode_topk_indices: torch.Tensor | None = None
c128a_decode_topk_lens: torch.Tensor | None = None
# Prefill: local topk indices (used by combine_topk_swa_indices).
c128a_prefill_topk_indices: torch.Tensor | None = None
def get_prefill_workspace_size(max_model_len: int):
# NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
......@@ -235,8 +281,9 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
parallel_config = vllm_config.parallel_config
self.device = device
# Treat requests with query length <= 1 as decodes to match the
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
# Classify single-token queries (plus num_speculative_tokens via
# supports_spec_as_decode=True) as decodes; longer queries go to
# prefill.
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
sm_count = num_compute_units(device.index)
......@@ -300,6 +347,68 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
device=device,
)
# DeepseekV4: has compress_ratios in hf_config.
hf_config = vllm_config.model_config.hf_config
self.is_deepseek_v4 = (
hasattr(hf_config, "compress_ratios") and len(hf_config.compress_ratios) > 0
)
self.compress_ratio = 1
if self.is_deepseek_v4:
assert hasattr(self.kv_cache_spec, "compress_ratio")
self.compress_ratio = self.kv_cache_spec.compress_ratio
# Pre-allocate compressed slot mapping buffer for CUDA graph
# address stability when compress_ratio > 1.
if self.compress_ratio > 1:
max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens
)
self.compressed_slot_mapping_buffer = torch.empty(
max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
# Pre-allocate C128A topk buffers for CUDA graph address stability.
if self.compress_ratio == 128:
max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens
)
# Pad to B_TOPK alignment (128 covers both h_q=64 B_TOPK=64 and
# h_q=128 B_TOPK=128). FlashMLA decode asserts extra_topk % B_TOPK
# == 0; unaligned widths (e.g. 17 = ceil(2136/128)) crash the
# sm100 head64 kernel. Padded slots stay -1 and decode_lens caps
# them via topk_length, so the pad is a no-op at kernel level.
# Mirrors _SPARSE_PREFILL_TOPK_ALIGNMENT in cache_utils.py.
_C128A_TOPK_ALIGNMENT = 128
c128a_max_compressed = cdiv(
self.model_config.max_model_len, self.compress_ratio
)
c128a_max_compressed = (
cdiv(c128a_max_compressed, _C128A_TOPK_ALIGNMENT)
* _C128A_TOPK_ALIGNMENT
)
# Stored so _build_c128a_metadata passes it as the kernel's
# max_compressed_tokens, matching the buffer stride. Otherwise
# the kernel's default 8192 iterates past row width and spills
# writes into adjacent rows (present in both decode and prefill
# branches of _build_c128a_topk_metadata_kernel).
self.c128a_max_compressed = c128a_max_compressed
self.c128a_global_decode_buffer = torch.empty(
(max_num_batched_tokens, c128a_max_compressed),
dtype=torch.int32,
device=self.device,
)
self.c128a_decode_lens_buffer = torch.empty(
max_num_batched_tokens,
dtype=torch.int32,
device=self.device,
)
self.c128a_prefill_buffer = torch.empty(
(max_num_batched_tokens, c128a_max_compressed),
dtype=torch.int32,
device=self.device,
)
def _build_fp8_mixed_decode_prefill(
self,
common_attn_metadata: CommonAttentionMetadata,
......@@ -460,15 +569,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
# Use padded head count since that's what the kernel will see
padded_heads = self.fp8_decode_padded_heads
scheduler_metadata, _ = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
num_q_tokens_per_head_k=decode_query_len * padded_heads,
topk=self.topk_tokens,
num_heads_q=padded_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
scheduler_metadata, _ = get_mla_metadata()
kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=scheduler_metadata,
......@@ -476,6 +577,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
cache_lens=self.max_model_len_tensor[:num_decodes],
)
fp8_metadata.decode = FP8Meta.Decode(
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
kernel_metadata=kernel_meta,
decode_query_len=decode_query_len,
)
......@@ -502,35 +604,109 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
slot_mapping = cm.slot_mapping
if self.compress_ratio > 1:
slot_mapping = get_compressed_slot_mapping(
common_attn_metadata.num_actual_tokens,
common_attn_metadata.query_start_loc,
common_attn_metadata.seq_lens,
common_attn_metadata.block_table_tensor.clamp(min=0),
int(self.kv_cache_spec.storage_block_size),
self.compress_ratio,
out=self.compressed_slot_mapping_buffer,
)
fp8_extra_metadata: (
FlashMLASparseMetadata.FP8SeparatePrefillDecode
| FlashMLASparseMetadata.FP8KernelMetadata
| None
) = None
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
if self.use_fp8_kv_cache:
fp8_use_mixed_batch = (
self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4
)
# DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not
# consume fp8_extra_metadata. Skipping the build here avoids a
# forced D2H sync on seq_lens that would otherwise fire on every
# prefill-bearing step, lifting GPU utilization on long-prefill
# workloads (e.g. LongBench) from ~83% to ~100%.
if self.use_fp8_kv_cache and not self.is_deepseek_v4:
if fp8_use_mixed_batch:
fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
else:
fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)
# Pre-compute C128A topk indices for DeepseekV4.
c128a_fields = {}
if self.is_deepseek_v4 and self.compress_ratio == 128:
c128a_fields = self._build_c128a_metadata(cm, req_id_per_token)
metadata = FlashMLASparseMetadata(
num_reqs=cm.num_reqs,
max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens,
query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping,
slot_mapping=slot_mapping,
block_table=cm.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
fp8_extra_metadata=fp8_extra_metadata,
fp8_use_mixed_batch=fp8_use_mixed_batch,
**c128a_fields,
)
return metadata
def _build_c128a_metadata(
    self,
    cm: CommonAttentionMetadata,
    req_id_per_token: torch.Tensor,
) -> dict[str, torch.Tensor | None]:
    """Pre-compute the C128A topk metadata fields for DeepseekV4.

    Returns a dict of ``FlashMLASparseMetadata`` keyword arguments. It holds
    the decode entries only when there are decode tokens, the prefill entry
    only when there are prefill tokens, and is empty when there are neither.
    """
    # Must match SWA's decode split (no `require_uniform=True`) so
    # `c128a_global_decode_topk_indices.shape[0]` lines up with q in
    # `_forward_decode`. The per-token C128A kernel handles non-uniform
    # query lengths.
    split = split_decodes_and_prefills(
        cm,
        decode_threshold=self.reorder_batch_threshold or 1,
    )
    num_decodes, _, num_decode_tokens, num_prefill_tokens = split

    num_total = num_decode_tokens + num_prefill_tokens
    if num_total == 0:
        return {}

    assert cm.positions is not None, (
        "positions is required for C128A metadata build"
    )

    # Block size measured in compressed entries.
    compressed_block_size = self.kv_cache_spec.block_size // self.compress_ratio
    decode_indices, decode_lens, prefill_indices = build_c128a_topk_metadata(
        cm.positions[:num_total],
        self.compress_ratio,
        num_decode_tokens,
        req_id_per_token,
        cm.block_table_tensor[:num_decodes],
        compressed_block_size,
        cm.slot_mapping,
        self.c128a_global_decode_buffer,
        self.c128a_decode_lens_buffer,
        self.c128a_prefill_buffer,
        max_compressed_tokens=self.c128a_max_compressed,
    )

    fields: dict[str, torch.Tensor | None] = {}
    if num_decode_tokens > 0:
        # Insert a singleton middle dimension: (tokens, 1, width).
        fields["c128a_global_decode_topk_indices"] = decode_indices.view(
            num_decode_tokens, 1, -1
        )
        fields["c128a_decode_topk_lens"] = decode_lens
    if num_prefill_tokens > 0:
        fields["c128a_prefill_topk_indices"] = prefill_indices
    return fields
class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
@staticmethod
......@@ -552,7 +728,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
topk_indice_buffer: torch.Tensor | None = None,
topk_indices_buffer: torch.Tensor | None = None,
indexer: "Indexer | None" = None,
**mla_args,
) -> None:
......@@ -615,7 +791,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
NUM_TOPK_TOKENS=topk_indices.shape[1],
)
return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
return self._bf16_flash_mla_kernel(
q,
kv_c_and_k_pe_cache,
topk_indices,
)
def _forward_fp8_kv_separate_prefill_decode(
self,
......@@ -656,7 +836,10 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
def _fp8_decode(
q: torch.Tensor,
topk_indices: torch.Tensor,
) -> torch.Tensor:
# Reshape q: (num_decode_tokens, num_heads, head_dim)
# -> (num_decodes, seq_len, num_heads, head_dim)
q = reshape_query_for_spec_decode(q, num_decodes)
......@@ -692,7 +875,8 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
if num_decode_tokens > 0:
attn_out[:num_decode_tokens] = _fp8_decode(
q[:num_decode_tokens], topk_indices[:num_decode_tokens]
q[:num_decode_tokens],
topk_indices[:num_decode_tokens],
)
assert fp8_metadata.prefill is not None
......@@ -823,6 +1007,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
output = flash_mla_sparse_fwd(
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
)[0]
output = output[:, : self.num_heads, :]
return output
......@@ -864,3 +1049,123 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
)
return attn_out, None
def build_c128a_topk_metadata(
    positions: torch.Tensor,
    compress_ratio: int,
    num_decode_tokens: int,
    token_to_req_indices: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    slot_mapping: torch.Tensor,
    global_decode_buffer: torch.Tensor,
    decode_lens_buffer: torch.Tensor,
    prefill_buffer: torch.Tensor,
    max_compressed_tokens: int = 8192,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Single kernel for all C128A tokens (decode + prefill).

    Decode tokens: position → block_table lookup → global slot ids + topk_lens.
    Prefill tokens: position → local indices [0, ..., n-1, -1, ...].

    Writes into pre-allocated buffers for CUDA graph address stability.

    Args:
        positions: Per-token positions; the first ``num_decode_tokens``
            entries are treated as decode tokens, the rest as prefill.
        compress_ratio: Compression ratio used to derive the number of
            compressed entries per token.
        num_decode_tokens: Number of leading decode tokens in ``positions``.
        token_to_req_indices: Per-token request index into ``block_table``.
        block_table: Block table indexed by request.
        block_size: Block size used for the block-table lookup.
        slot_mapping: Per-token slot mapping; decode tokens with a negative
            entry get a length of 0.
        global_decode_buffer: 2-D output buffer for decode slot ids.
        decode_lens_buffer: 1-D output buffer for decode valid-entry counts.
        prefill_buffer: 2-D output buffer for prefill local indices.
        max_compressed_tokens: Upper bound on entries written per row. It is
            clamped to the row width of every buffer that is written, so the
            default can never spill into an adjacent row.

    Returns:
        ``(global_decode, decode_lens, prefill_local)``: leading slices of the
        three buffers.
    """
    num_tokens = positions.shape[0]
    num_prefill_tokens = num_tokens - num_decode_tokens

    global_decode = global_decode_buffer[:num_decode_tokens]
    decode_lens = decode_lens_buffer[:num_decode_tokens]
    prefill_local = prefill_buffer[:num_prefill_tokens]

    if num_tokens == 0:
        return global_decode, decode_lens, prefill_local

    # The kernel writes `max_compressed_tokens` entries per row. Cap that at
    # the row width of each buffer that will actually be written; a larger
    # bound would overrun the row and clobber the next one.
    row_limits = [max_compressed_tokens]
    if num_decode_tokens > 0:
        row_limits.append(global_decode_buffer.shape[-1])
    if num_prefill_tokens > 0:
        row_limits.append(prefill_buffer.shape[-1])
    max_compressed_tokens = min(row_limits)

    _build_c128a_topk_metadata_kernel[(num_tokens,)](
        global_decode_buffer,
        global_decode_buffer.stride(0),
        decode_lens_buffer,
        prefill_buffer,
        prefill_buffer.stride(0),
        positions,
        compress_ratio,
        max_compressed_tokens,
        num_decode_tokens,
        token_to_req_indices,
        block_table,
        block_table.stride(0),
        block_size,
        slot_mapping,
        BLOCK_SIZE=1024,
    )
    return global_decode, decode_lens, prefill_local
@triton.jit
def _build_c128a_topk_metadata_kernel(
    # Decode outputs
    global_decode_ptr,
    global_decode_stride,
    decode_lens_ptr,
    # Prefill output
    prefill_local_ptr,
    prefill_local_stride,
    # Inputs
    positions_ptr,
    compress_ratio,
    max_compressed_tokens,
    num_decode_tokens,
    token_to_req_indices_ptr,
    block_table_ptr,
    block_table_stride,
    block_size,
    slot_mapping_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per token. The first `num_decode_tokens` programs handle
    # decode tokens; the remaining programs handle prefill tokens.
    token_idx = tl.program_id(0)
    position = tl.load(positions_ptr + token_idx)
    # Number of compressed entries visible to this token, capped at the
    # per-row limit.
    num_compressed = tl.minimum(
        (position + 1) // compress_ratio, max_compressed_tokens
    )

    if token_idx < num_decode_tokens:
        # --- Decode: block-table lookup → global slot ids + count ---
        token_has_slot = tl.load(slot_mapping_ptr + token_idx) >= 0
        req_idx = tl.load(token_to_req_indices_ptr + token_idx)
        table_row_ptr = block_table_ptr + req_idx * block_table_stride
        out_row_ptr = global_decode_ptr + token_idx * global_decode_stride

        count = tl.zeros((), dtype=tl.int32)
        for start in range(0, max_compressed_tokens, BLOCK_SIZE):
            cols = start + tl.arange(0, BLOCK_SIZE)
            in_row = cols < max_compressed_tokens
            in_use = cols < num_compressed

            block_numbers = tl.load(
                table_row_ptr + cols // block_size,
                mask=in_row & in_use,
            )
            slot_ids = block_numbers * block_size + cols % block_size
            # Entries past `num_compressed` are written as -1.
            slot_ids = tl.where(in_use, slot_ids, -1)
            tl.store(out_row_ptr + cols, slot_ids, mask=in_row)
            count += tl.sum(in_use.to(tl.int32), axis=0)

        # Tokens whose slot mapping is negative report a length of 0.
        tl.store(
            decode_lens_ptr + token_idx,
            tl.where(token_has_slot, count, 0),
        )
    else:
        # --- Prefill: write local indices ---
        pfx_idx = token_idx - num_decode_tokens
        out_row_ptr = prefill_local_ptr + pfx_idx * prefill_local_stride
        for start in range(0, max_compressed_tokens, BLOCK_SIZE):
            cols = start + tl.arange(0, BLOCK_SIZE)
            in_row = cols < max_compressed_tokens
            # Local indices 0..num_compressed-1, then -1 for the remainder.
            tl.store(
                out_row_ptr + cols,
                tl.where(cols < num_compressed, cols, -1),
                mask=in_row,
            )
......@@ -22,10 +22,11 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
......@@ -154,6 +155,16 @@ class DeepseekV32IndexerBackend(AttentionBackend):
return (0, 1, 2)
class DeepseekV4IndexerBackend(DeepseekV32IndexerBackend):
    """Indexer backend variant that reuses the V3.2 indexer backend.

    Only the backend name and the list of supported kernel block sizes are
    overridden here; every other hook is inherited from
    ``DeepseekV32IndexerBackend``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "DEEPSEEK_V4_INDEXER"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block sizes this backend accepts (256 only)."""
        return [256]
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
block_table: torch.Tensor
......@@ -179,7 +190,7 @@ class DeepSeekV32IndexerDecodeMetadata:
# seq_lens: per-token effective context lengths.
# - flatten path / plain decode: 1D (batch_size,)
# - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1
# Both fp8_paged_mqa_logits and the topk kernels accept both shapes.
# Both fp8_fp4_paged_mqa_logits and the topk kernels accept both shapes.
seq_lens: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
......@@ -191,16 +202,8 @@ class DeepseekV32IndexerMetadata:
# FIXME (zyongye)
# hacky way to access the data now, need to be in chunked meta
seq_lens: torch.Tensor
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# The dimension of the attention heads
head_dim: int
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
......@@ -213,71 +216,6 @@ class DeepseekV32IndexerMetadata:
prefill: DeepseekV32IndexerPrefillMetadata | None = None
# TODO (zyongye) optimize this, this is now vibe coded
def kv_spans_from_batches(
    start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute, per selected token, its KV span in the concatenated KV cache.

    Assumes each batch contributes its full ``seq_len_per_batch[i]`` keys to
    the concatenated KV cache, and that the selected tokens within a batch are
    the **last** ``counts[i]`` positions of that sequence.

    Args:
        start_seq_loc: 1D integer tensor [B+1], cumulative counts of selected
            tokens per batch. Only the differences between consecutive
            entries matter, so the first entry does not have to be zero.
            Example: [0, 2, 4, 7] -> selected batch sizes [2, 2, 3],
            N=7 tokens total.
        seq_len_per_batch: 1D integer tensor [B], full sequence length
            (KV length) of each batch. Example: [5, 9, 4].
        device: Device on which the two result tensors are returned.

    Returns:
        A ``(start_tensor, end_location)`` pair of 1D int32 tensors [N]:
        ``start_tensor`` is the start offset, in the concatenated KV cache,
        of the batch each token belongs to; ``end_location`` is the
        **exclusive** end (start + the token's 1-based local position), so
        the attended KV slice is ``kv[start:end]``. Both are empty int32
        tensors when no token is selected.
    """
    q = start_seq_loc.to(dtype=torch.long)
    # Keep all intermediate math on one device (the one holding start_seq_loc).
    work_device = q.device
    L = seq_len_per_batch.to(device=work_device, dtype=torch.long)
    assert q.dim() == 1 and L.dim() == 1
    assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch and totals. Summing the counts (rather than
    # reading q[-1]) keeps this correct when q does not start at zero.
    counts = q[1:] - q[:-1]  # [B]
    N = int(counts.sum().item())  # total selected tokens
    B = L.numel()

    if N == 0:
        # Same dtype as the non-empty path so callers see a uniform int32.
        return (
            torch.empty(0, dtype=torch.int32, device=device),
            torch.empty(0, dtype=torch.int32, device=device),
        )

    # KV start offsets per batch in the concatenated KV cache
    kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

    # For each selected token, which batch does it belong to?
    batch_id = torch.repeat_interleave(
        torch.arange(B, dtype=torch.long, device=work_device), counts
    )  # [N]

    # Map batch KV start to each token
    start_tensor = kv_starts_per_batch[batch_id]  # [N]

    # End-align local positions inside each batch:
    # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
    first_token_of_batch = (q[:-1] - q[0])[batch_id]  # [N], zero-based
    # position within the selected block: 1..counts[b]
    pos_within = (
        torch.arange(N, dtype=torch.long, device=work_device)
        - first_token_of_batch
        + 1
    )
    local_pos = L[batch_id] - counts[batch_id] + pos_within  # [N], 1-based
    end_location = start_tensor + local_pos  # exclusive end

    return start_tensor.int().to(device), end_location.int().to(device)
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
max_model_len = vllm_config.model_config.max_model_len
# NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
......@@ -293,7 +231,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
reorder_batch_threshold: int = 1
natively_supported_next_n: list[int] = [1, 2]
natively_supported_next_n_fp4: list[int] = [1, 2]
# TODO (matt): integrate kernel with next_n = 4 support
@classmethod
......@@ -314,9 +252,30 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config
else 0
)
self.use_fp4_indexer_cache = (
self.vllm_config.attention_config.use_fp4_indexer_cache
)
assert (
current_platform.is_device_capability_family(100)
or not self.use_fp4_indexer_cache
), (
"use_fp4_indexer_cache requires Blackwell datacenter GPUs "
"(sm_10x, e.g. B200/GB200); sm_120 (consumer Blackwell) and "
"earlier architectures are not supported."
)
next_n = self.num_speculative_tokens + 1
self.reorder_batch_threshold += self.num_speculative_tokens
self.use_flattening = next_n not in self.natively_supported_next_n
# NOTE(zyongye) fp4 indexer cache only natively supports next_n in
# natively_supported_next_n_fp4; for other next_n values we fall back
# to the flattening path. Outside the SM100 datacenter family the FP8
# paged MQA logits kernel has the same [1, 2] constraint (deepgemm
# smxx_fp8_fp4_paged_mqa_logits.hpp:233), so flatten there too.
self.use_flattening = (
self.use_fp4_indexer_cache
or not current_platform.is_device_capability_family(100)
) and next_n not in self.natively_supported_next_n_fp4
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
......@@ -331,7 +290,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
)
if not self.use_flattening and next_n > 1:
# Native MTP: 2D buffer for per-token seq_lens.
# Flattening path is never used, so no expanded_seq_lens_buffer.
self.decode_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_seqs, next_n),
dtype=torch.int32,
......@@ -367,53 +325,27 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
(self.num_sms + 1, 2), dtype=torch.int32, device=self.device
)
def build_one_prefill_chunk(
self,
req_slice: slice,
query_slice: slice,
query_start_loc_cpu,
seq_lens_cpu,
block_table,
skip_kv_gather: bool = False,
) -> DeepseekV32IndexerPrefillChunkMetadata:
prefill_query_start_loc = (
query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
- query_start_loc_cpu[req_slice.start]
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
)
token_start = query_start_loc_cpu[req_slice.start].item()
total_seq_lens = seq_lens_cpu[req_slice].sum()
num_reqs = req_slice.stop - req_slice.start
seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
self.device
)
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = (
torch.cat(
[
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[req_slice].cumsum(dim=0),
]
# KV compression. Default to 1 for no compression.
self.compress_ratio = 1
# Get compress_ratio for DeepseekV4 support
if isinstance(self.kv_cache_spec, MLAAttentionSpec):
self.compress_ratio = self.kv_cache_spec.compress_ratio
# Pre-allocate buffers for CUDA graph compatibility when
if self.compress_ratio > 1:
# compress_ratio > 1 (DeepseekV4)
# Compressed slot mapping output buffer
self.compressed_slot_mapping_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int64,
device=self.device,
)
# Buffer for compressed seq_lens in decode path
self.expanded_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
.to(torch.int32)
.to(self.device)
)
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks[query_slice],
cu_seqlen_ke=cu_seqlen_ke[query_slice],
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
block_table=block_table[req_slice],
token_start=token_start + query_slice.start,
token_end=token_start + query_slice.stop,
num_reqs=num_reqs,
skip_kv_gather=skip_kv_gather,
)
def _prepare_decode_tensors(
self,
......@@ -520,11 +452,15 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding = min_decode_len != max_decode_len
if use_native and next_n > 1:
assert self.decode_seq_lens_buffer.dim() == 2
# (B, next_n): token j attends to L - next_n + j + 1 KV tokens
self.decode_seq_lens_buffer[:num_decodes] = (
seq_lens.unsqueeze(1) - next_n + 1 + self.offsets_buffer
# (B, max_decode_len): token j attends to
# L - max_decode_len + j + 1 KV tokens.
self.decode_seq_lens_buffer[:num_decodes, :max_decode_len] = (
seq_lens.unsqueeze(1)
- max_decode_len
+ 1
+ self.offsets_buffer[:max_decode_len]
)
seq_lens = self.decode_seq_lens_buffer[:num_decodes]
seq_lens = self.decode_seq_lens_buffer[:num_decodes, :max_decode_len]
return seq_lens, block_table, decode_lens, num_decodes, requires_padding
def build(
......@@ -535,8 +471,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
) -> DeepseekV32IndexerMetadata:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
slot_mapping = common_attn_metadata.slot_mapping
block_table = common_attn_metadata.block_table_tensor
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
......@@ -548,8 +488,32 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
compressed_slot_mapping = slot_mapping
compressed_seq_lens = seq_lens
if self.compress_ratio > 1:
compressed_slot_mapping = get_compressed_slot_mapping(
num_tokens,
query_start_loc,
seq_lens,
block_table,
self.kv_cache_spec.storage_block_size,
self.compress_ratio,
out=self.compressed_slot_mapping_buffer,
)
compressed_seq_lens = seq_lens // self.compress_ratio
prefill_metadata = None
if num_prefills > 0:
# This CPU value is an upper bound for async-spec extend rows. It
# is safe for chunking/allocation because CUDA metadata below is
# built from exact device seq_lens and gather ignores the tail.
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
compressed_seq_lens_cpu = (
seq_lens_cpu // self.compress_ratio
if self.compress_ratio > 1
else seq_lens_cpu
)
prefill_query_lens_cpu = torch.diff(
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
)
......@@ -559,26 +523,32 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
chunk_specs = split_indexer_prefill_chunks(
seq_lens_cpu[num_decodes:],
compressed_seq_lens_cpu[num_decodes:],
prefill_query_lens_cpu,
self.max_prefill_buffer_size,
max_logits_bytes,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
req_slice,
query_slice,
chunks = []
for req_slice, query_slice in chunk_specs:
metadata = build_prefill_chunk_metadata(
req_slice.start,
req_slice.stop,
query_start_loc,
query_start_loc_cpu,
seq_lens_cpu,
seq_lens,
compressed_seq_lens,
compressed_seq_lens_cpu,
common_attn_metadata.block_table_tensor,
self.compress_ratio,
query_slice=query_slice,
skip_kv_gather=query_slice.start > 0,
)
for req_slice, query_slice in chunk_specs
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks,
)
# Skip when total_seq_lens is 0 (i.e., no compressed token).
if metadata is not None:
chunks.append(metadata)
prefill_metadata = DeepseekV32IndexerPrefillMetadata(chunks)
decode_metadata = None
if num_decodes > 0:
......@@ -596,7 +566,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
max_decode_len = int(decode_lens_cpu.max().item())
next_n = 1 + self.num_speculative_tokens
use_native = not self.use_flattening and max_decode_len == next_n
use_native = not self.use_flattening and max_decode_len <= next_n
seq_lens, block_table, decode_lens, batch_size, requires_padding = (
self._prepare_decode_tensors(
......@@ -613,11 +583,35 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
)
)
# For DeepseekV4 (compress_ratio > 1), the indexer KV cache stores
# compressed tokens. Convert uncompressed seq_lens to compressed.
if self.compress_ratio > 1:
# True iff seq_lens aliases decode_seq_lens_buffer (flatten or
# native wrote it); False iff it aliases common_attn_metadata.
seq_lens_is_local_view = (use_native and next_n > 1) or (
not use_native and max_decode_len > 1
)
if seq_lens_is_local_view:
seq_lens //= self.compress_ratio
else:
# Copy to avoid mutating shared state; keeps CG address stable.
self.expanded_seq_lens_buffer[:num_decodes] = (
seq_lens // self.compress_ratio
)
self.expanded_seq_lens_buffer[num_decodes:num_decode_tokens] = 0
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
# Non-MTP: deep_gemm paged MQA logits requires 2D context_lens
# (csrc/apis/attention.hpp). Unsqueeze to (B, 1) so downstream
# kernels see the same (B, next_n) layout as the MTP path.
if seq_lens.dim() == 1:
seq_lens = seq_lens.unsqueeze(-1)
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens,
self.kv_cache_spec.block_size,
self.kv_cache_spec.storage_block_size,
self.num_sms,
)
......@@ -631,13 +625,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
attn_metadata = DeepseekV32IndexerMetadata(
seq_lens=common_attn_metadata.seq_lens,
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128,
slot_mapping=compressed_slot_mapping,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
......@@ -647,3 +636,138 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
)
return attn_metadata
def build_prefill_chunk_metadata(
    start_idx: int,
    end_idx: int,
    query_start_loc: torch.Tensor,
    query_start_loc_cpu: torch.Tensor,
    uncompressed_seq_lens: torch.Tensor,
    compressed_seq_lens: torch.Tensor,
    compressed_seq_lens_cpu: torch.Tensor,
    block_table: torch.Tensor,
    compress_ratio: int,
    query_slice: slice | None = None,
    skip_kv_gather: bool = False,
) -> DeepseekV32IndexerPrefillChunkMetadata | None:
    """Build the indexer prefill metadata for requests ``[start_idx, end_idx)``.

    Output tensors are allocated on ``block_table.device`` and filled by
    ``_build_prefill_chunk_metadata_kernel`` (one program per request).

    Args:
        start_idx: First request index of the chunk (inclusive).
        end_idx: Last request index of the chunk (exclusive).
        query_start_loc: Cumulative query offsets passed to the kernel after
            being rebased to the chunk's first request.
        query_start_loc_cpu: Cumulative query offsets read on the host to
            derive token ranges without touching ``query_start_loc``.
        uncompressed_seq_lens: Per-request sequence lengths before
            compression; the chunk's slice is passed to the kernel.
        compressed_seq_lens: Per-request compressed sequence lengths used for
            the cumulative sum written into ``cu_seq_lens``.
        compressed_seq_lens_cpu: Host-side compressed lengths; their sum over
            the chunk sizes ``token_to_seq`` and decides the early exit.
        block_table: Block table; rows ``[start_idx, end_idx)`` are forwarded.
        compress_ratio: Compile-time compression ratio handed to the kernel.
        query_slice: Optional sub-range of the chunk's query tokens
            (chunk-relative). When ``None`` the whole chunk is covered.
        skip_kv_gather: Forwarded to the metadata; additionally forced truthy
            when ``query_slice`` starts after the first query token.

    Returns:
        The chunk metadata, or ``None`` when the chunk's total compressed
        sequence length is zero.
    """
    chunk_kv_len = compressed_seq_lens_cpu[start_idx:end_idx].sum().item()
    if chunk_kv_len == 0:
        return None

    req_count = end_idx - start_idx
    target_device = block_table.device

    token_to_seq = torch.empty(chunk_kv_len, dtype=torch.int32, device=target_device)
    cu_seq_lens = torch.empty(req_count + 1, dtype=torch.int32, device=target_device)
    # Assigning to slice avoids cpu sync.
    cu_seq_lens[:1] = 0
    torch.cumsum(compressed_seq_lens[start_idx:end_idx], dim=0, out=cu_seq_lens[1:])

    # Rebase the device-side query offsets so the chunk starts at zero.
    chunk_query_start_loc = (
        query_start_loc[start_idx : end_idx + 1] - query_start_loc[start_idx]
    )

    chunk_query_len = int(
        (query_start_loc_cpu[end_idx] - query_start_loc_cpu[start_idx]).item()
    )
    if query_slice is None:
        qs_start, qs_stop = 0, chunk_query_len
    else:
        qs_start, qs_stop = query_slice.start, query_slice.stop
    sliced_query_len = qs_stop - qs_start

    cu_seq_len_ks = torch.empty(
        sliced_query_len, dtype=torch.int32, device=target_device
    )
    cu_seq_len_ke = torch.empty(
        sliced_query_len, dtype=torch.int32, device=target_device
    )

    _build_prefill_chunk_metadata_kernel[(req_count,)](
        chunk_query_start_loc,
        uncompressed_seq_lens[start_idx:end_idx],
        cu_seq_lens,
        token_to_seq,
        cu_seq_len_ks,
        cu_seq_len_ke,
        qs_start,
        qs_stop,
        BLOCK_SIZE=1024,
        COMPRESS_RATIO=compress_ratio,
    )

    # Absolute token range covered by this (possibly sliced) chunk.
    chunk_first_token = query_start_loc_cpu[start_idx].item()
    if query_slice is None:
        token_start = chunk_first_token
        token_end = query_start_loc_cpu[end_idx].item()
    else:
        token_start = chunk_first_token + qs_start
        token_end = chunk_first_token + qs_stop
        skip_kv_gather = skip_kv_gather or qs_start > 0

    return DeepseekV32IndexerPrefillChunkMetadata(
        cu_seqlen_ks=cu_seq_len_ks,
        cu_seqlen_ke=cu_seq_len_ke,
        cu_seq_lens=cu_seq_lens,
        token_to_seq=token_to_seq,
        total_seq_lens=chunk_kv_len,
        block_table=block_table[start_idx:end_idx],
        token_start=token_start,
        token_end=token_end,
        num_reqs=req_count,
        skip_kv_gather=skip_kv_gather,
    )
@triton.jit
def _build_prefill_chunk_metadata_kernel(
    # Inputs
    query_start_loc_ptr,
    uncompressed_seq_lens_ptr,
    cu_compressed_seq_lens_ptr,
    # Outputs
    token_to_seq_ptr,
    cu_compressed_seq_len_ks_ptr,
    cu_compressed_seq_len_ke_ptr,
    # Query-slice bounds (scalar inputs): only query positions in
    # [query_slice_start, query_slice_stop) are written to the ks/ke outputs.
    query_slice_start,
    query_slice_stop,
    BLOCK_SIZE: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
):
    """Fill per-token KV spans and the token-to-request map for one request.

    Each program handles the request selected by ``tl.program_id(0)``. For
    every query token of that request inside the query slice it stores the
    request's compressed KV start (``ks``) and the exclusive compressed KV
    end (``ke``) at the slice-relative position. It also writes the request
    index into ``token_to_seq`` for every compressed KV position of the
    request.
    """
    batch_idx = tl.program_id(0)
    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start
    # Range of this request inside the concatenated compressed KV.
    seq_start = tl.load(cu_compressed_seq_lens_ptr + batch_idx)
    seq_end = tl.load(cu_compressed_seq_lens_ptr + batch_idx + 1)
    compressed_seq_len = seq_end - seq_start
    uncompressed_seq_len = tl.load(uncompressed_seq_lens_ptr + batch_idx)
    # Uncompressed position of the request's first query token.
    start_pos = uncompressed_seq_len - query_len
    for i in range(0, query_len, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        abs_pos = query_start + offset
        # Keep lanes that are real query tokens AND fall inside the slice.
        mask = (
            (offset < query_len)
            & (abs_pos >= query_slice_start)
            & (abs_pos < query_slice_stop)
        )
        # Outputs are indexed relative to the start of the query slice.
        out_pos = abs_pos - query_slice_start
        # Compute cu_seq_len_ks
        tl.store(cu_compressed_seq_len_ks_ptr + out_pos, seq_start, mask=mask)
        # Compute cu_seq_len_ke
        # Number of compressed KV entries visible to the token at uncompressed
        # position (start_pos + offset): floor((position + 1) / COMPRESS_RATIO).
        seq_len_per_token = (start_pos + 1 + offset) // COMPRESS_RATIO
        tl.store(
            cu_compressed_seq_len_ke_ptr + out_pos,
            seq_start + seq_len_per_token,
            mask=mask,
        )
    # Compute token_to_seq
    for i in range(0, compressed_seq_len, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        mask = offset < compressed_seq_len
        tl.store(token_to_seq_ptr + seq_start + offset, batch_idx, mask=mask)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, cast
import torch
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata
from vllm.v1.kv_cache_interface import (
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowMLASpec,
)
# DeepseekV4 decode layer types, keyed by compress_ratio. Each type has a distinct
# (topk, extra_topk, extra_page_block_size) config, so they cannot share a
# FlashMLA tile-scheduler plan. Within a type, all ~60 DeepseekV4 layers share one
# plan per step because b / s_q / h_q / page_block_sizes / topks are identical.
_LAYER_TYPE_SWAONLY = "swaonly"
_LAYER_TYPE_C4A = "c4a"
_LAYER_TYPE_C128A = "c128a"


def _layer_type_for(compress_ratio: int) -> str:
    """Return the layer-type key that corresponds to ``compress_ratio``.

    Ratios of 1 or less map to the SWA-only type, 4 maps to C4A and 128 maps
    to C128A.

    Raises:
        ValueError: If ``compress_ratio`` is greater than 1 and is neither 4
            nor 128.
    """
    if compress_ratio <= 1:
        return _LAYER_TYPE_SWAONLY

    compressed_types = {4: _LAYER_TYPE_C4A, 128: _LAYER_TYPE_C128A}
    layer_type = compressed_types.get(compress_ratio)
    if layer_type is None:
        raise ValueError(
            f"Unsupported DeepseekV4 compress_ratio={compress_ratio}; "
            "expected 1, 4, or 128."
        )
    return layer_type
class DeepseekV4SWACache(torch.nn.Module, AttentionLayerBase):
    """Attention-layer placeholder that owns the sliding-window KV cache.

    The module registers itself in the compilation config's
    ``static_forward_context`` under ``prefix`` and describes its cache via
    ``get_kv_cache_spec``; ``forward`` does nothing.
    """

    def __init__(
        self,
        head_dim: int,
        window_size: int,
        dtype: torch.dtype,
        prefix: str,
        cache_config: CacheConfig,
    ):
        """Store the cache geometry and register the layer under ``prefix``.

        Raises:
            ValueError: If another layer is already registered as ``prefix``.
        """
        super().__init__()
        # Placeholder tensor; starts out empty.
        self.kv_cache = torch.tensor([])
        self.head_dim = head_dim
        self.window_size = window_size
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype

        forward_context = (
            get_current_vllm_config().compilation_config.static_forward_context
        )
        if prefix in forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        forward_context[prefix] = self

        # Block size is constrained by tensor sharing between SWA and C4A KV blocks.
        # Since both block types share the same physical tensor, they must use the
        # same page size. The C4A KV block shape [256//4, head_dim] = [64, head_dim]
        # determines the SWA block size of 64 tokens per block.
        # TODO(yifan): make SWA block size automatically determined and configurable.
        self.block_size = 64
        assert self.dtype == torch.uint8

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the sliding-window MLA cache spec built from this layer."""
        spec_kwargs = dict(
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            sliding_window=self.window_size,
            cache_dtype_str=self.cache_config.cache_dtype,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
            model_version="deepseek_v4",
        )
        return SlidingWindowMLASpec(**spec_kwargs)

    def forward(self): ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class paired with this cache layer."""
        return DeepseekSparseSWABackend
class DeepseekSparseSWABackend(AttentionBackend):
    """Attention backend descriptor for the DeepSeek sparse SWA KV cache."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "DEEPSEEK_SPARSE_SWA"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the accepted kernel block sizes: any multiple of 64."""
        return [MultipleOf(64)]

    @classmethod
    def get_preferred_block_size(cls, default_block_size: int) -> int:
        """Return 256 regardless of ``default_block_size``."""
        return 256

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes (512 only)."""
        return [512]

    @staticmethod
    def get_builder_cls() -> type["DeepseekSparseSWAMetadataBuilder"]:
        """Return the metadata builder class used with this backend."""
        return DeepseekSparseSWAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the ``(num_blocks, block_size, per_token_width)`` cache shape.

        Exactly one KV head is required. For the ``"fp8_ds_mla"`` cache dtype
        the per-token width is fixed at 584; otherwise it is ``head_size``.
        """
        assert num_kv_heads == 1
        # DeepseekV4 SWA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale).
        # head_size passed in is the semantic head_dim (512).
        per_token_width = 584 if cache_dtype_str == "fp8_ds_mla" else head_size
        return (num_blocks, block_size, per_token_width)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the identity stride order, with or without a layers dim."""
        return (0, 1, 2, 3) if include_num_layers_dimension else (0, 1, 2)
@dataclass
class DeepseekSparseSWAMetadata:
    """Per-step metadata produced by ``DeepseekSparseSWAMetadataBuilder.build``.

    Shape comments below use ``num_seqs`` for the number of requests and
    ``num_tokens`` for the number of query tokens in the step.
    """

    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    block_size: int

    seq_lens: torch.Tensor | None = None  # [num_seqs]
    query_start_loc: torch.Tensor | None = None  # [num_seqs + 1]
    query_start_loc_cpu: torch.Tensor | None = None  # [num_seqs + 1]
    # True where the token's slot_mapping entry is >= 0.
    is_valid_token: torch.Tensor | None = None  # [num_tokens]
    token_to_req_indices: torch.Tensor | None = None  # [num_tokens]
    # Slot ids of each decode token's sliding window (-1 past the window
    # length). The builder slices this from a (max_tokens, 1, window_size)
    # buffer, so the view is [num_decode_tokens, 1, window_size].
    decode_swa_indices: torch.Tensor | None = None  # [num_decode_tokens, 1, window_size]
    decode_swa_lens: torch.Tensor | None = None  # [num_decode_tokens]

    # Number of decode/prefill requests/tokens (batch is reordered: decodes first)
    num_decodes: int = 0
    num_prefills: int = 0
    num_decode_tokens: int = 0
    num_prefill_tokens: int = 0

    # Pre-computed prefill metadata shared across all DeepseekV4 attention layers.
    # Both stay None when the step has no prefill requests.
    prefill_seq_lens: torch.Tensor | None = None
    prefill_gather_lens: torch.Tensor | None = None

    # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta
    # per present DeepseekV4 layer type, shared across all ~60 layers of that type
    # within a decode step. The first forward call of a given type triggers
    # the in-kernel planner (which also allocates tile_scheduler_metadata and
    # num_splits via PyTorch's graph-aware allocator); subsequent same-type
    # calls skip planning and reuse the plan. Fresh instance per build(), so
    # have_initialized is always False at the start of a step and the plan
    # is re-derived from current seq_lens / topk_length on replay.
    # None for layer types the model does not use (or when num_decode_tokens
    # is zero).
    tile_sched_swaonly: "FlashMLASchedMeta | None" = None
    tile_sched_c4a: "FlashMLASchedMeta | None" = None
    tile_sched_c128a: "FlashMLASchedMeta | None" = None
class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder):
    """Metadata builder for the DeepseekV4 sliding-window (SWA) cache.

    Mixed batches are handled by first locating the decode/prefill boundary
    with ``split_decodes_and_prefills()`` and then producing decode-side and
    prefill-side metadata separately.

    Supports:
    - Mixed decode/prefill batches
    - MTP (Multi-Token Prediction) where decode has query_len > 1
    - Chunked prefill (aligns with the indexer's chunking)
    """

    # Base threshold: query_len <= 1 is decode
    reorder_batch_threshold: int = 1
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, *args, **kwargs):
        """Read the cache spec / model config and allocate persistent buffers."""
        super().__init__(*args, **kwargs)
        assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec)
        spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec)
        self.head_size = spec.head_size  # Already considered quantization.
        self.compress_ratio = spec.compress_ratio
        self.block_size = spec.block_size

        # Handle MTP: adjust decode_threshold like the indexer does
        speculative_config = self.vllm_config.speculative_config
        self.num_speculative_tokens = (
            speculative_config.num_speculative_tokens if speculative_config else 0
        )
        # With MTP, decode can have query_len up to 1 + num_speculative_tokens.
        # Must match the threshold used by the indexer and flashmla_sparse so
        # that all backends agree on the decode/prefill split.
        self.decode_threshold = (
            self.reorder_batch_threshold + self.num_speculative_tokens
        )

        hf_config = self.vllm_config.model_config.hf_config
        assert hasattr(hf_config, "sliding_window")
        self.window_size = hf_config.sliding_window

        # Detect which DeepseekV4 layer types this model uses so we only build a
        # FlashMLA tile-scheduler plan for types that will actually be called.
        # Models without compress_ratios (pure SWA) fall back to swaonly.
        compress_ratios = getattr(hf_config, "compress_ratios", None) or [1]
        self._layer_types: set[str] = {
            _layer_type_for(int(ratio)) for ratio in compress_ratios
        }

        # Persistent per-token buffers, sized for the largest possible batch.
        capacity = self.vllm_config.scheduler_config.max_num_batched_tokens
        self.token_to_req_indices = torch.zeros(
            capacity, dtype=torch.int32, device=self.device
        )
        self.decode_swa_indices = torch.zeros(
            capacity, 1, self.window_size, dtype=torch.int32, device=self.device
        )
        self.decode_swa_lens = torch.zeros(
            capacity, dtype=torch.int32, device=self.device
        )
        self.is_valid_token = torch.zeros(
            capacity, dtype=torch.bool, device=self.device
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> DeepseekSparseSWAMetadata:
        """Build the SWA metadata for one (possibly mixed) batch.

        The batch is assumed to arrive reordered with decodes first. The
        decode/prefill boundary comes from ``split_decodes_and_prefills()``
        using ``self.decode_threshold``. Decode tokens get their window slot
        ids and lengths from a Triton kernel; prefill requests get their
        gather lengths from ``_build_deepseek_v4_metadata``.
        """
        cam = common_attn_metadata
        num_reqs = cam.num_reqs
        seq_lens = cam.seq_lens
        query_start_loc = cam.query_start_loc
        query_start_loc_cpu = cam.query_start_loc_cpu
        block_table = cam.block_table_tensor
        slot_mapping = cam.slot_mapping

        # Split into decode and prefill portions using configurable threshold
        split = split_decodes_and_prefills(
            cam, decode_threshold=self.decode_threshold
        )
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

        # NOTE: Ensure all metadata tensors maintain fixed memory addresses
        # for CUDA graph compatibility.
        query_lens = torch.diff(query_start_loc_cpu)
        req_of_token_cpu = torch.repeat_interleave(
            torch.arange(num_reqs), query_lens
        ).pin_memory()
        token_to_req_indices = self.token_to_req_indices[: req_of_token_cpu.shape[0]]
        token_to_req_indices.copy_(req_of_token_cpu, non_blocking=True)

        is_valid_token = self.is_valid_token[: slot_mapping.shape[0]]
        is_valid_token.copy_(slot_mapping >= 0)

        if num_decode_tokens > 0:
            # Clear the window lengths of every row past the decode tokens.
            self.decode_swa_lens[num_decode_tokens:] = 0
            _compute_swa_indices_and_lens_kernel[(num_decode_tokens,)](
                self.decode_swa_indices,
                self.decode_swa_indices.stride(0),
                self.decode_swa_lens,
                self.window_size,
                query_start_loc,
                seq_lens,
                token_to_req_indices,
                is_valid_token,
                block_table,
                block_table.stride(0),
                self.block_size,
                TRITON_BLOCK_SIZE=1024,
            )

        # Pre-compute DeepseekV4 prefill metadata shared across all attention layers.
        prefill_fields = self._build_deepseek_v4_metadata(
            num_decodes,
            num_prefills,
            seq_lens,
            query_start_loc,
        )

        # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta
        # per present DeepseekV4 layer type; the first flash_mla_with_kvcache call of
        # each type triggers the planner and all same-type layers reuse the
        # resulting plan for the rest of the step.
        plans = self.build_tile_scheduler(num_decode_tokens)

        return DeepseekSparseSWAMetadata(
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            block_table=block_table,
            slot_mapping=slot_mapping,
            is_valid_token=is_valid_token,
            token_to_req_indices=token_to_req_indices,
            decode_swa_indices=self.decode_swa_indices[:num_decode_tokens],
            decode_swa_lens=self.decode_swa_lens[:num_decode_tokens],
            block_size=self.block_size,
            num_decodes=num_decodes,
            num_prefills=num_prefills,
            num_decode_tokens=num_decode_tokens,
            num_prefill_tokens=num_prefill_tokens,
            tile_sched_swaonly=plans[_LAYER_TYPE_SWAONLY],
            tile_sched_c4a=plans[_LAYER_TYPE_C4A],
            tile_sched_c128a=plans[_LAYER_TYPE_C128A],
            **prefill_fields,
        )

    def build_tile_scheduler(
        self, num_decode_tokens: int
    ) -> dict[str, FlashMLASchedMeta | None]:
        """Create one empty ``FlashMLASchedMeta`` per layer type in use.

        The result always has an entry for each of the three layer-type keys.
        Entries for types this model uses hold the first element returned by
        ``get_mla_metadata()``; per the surrounding design notes, the FlashMLA
        decode path is expected to plan on the first call of each type and
        reuse that plan afterwards. Every entry is ``None`` when
        ``num_decode_tokens`` is zero, giving ``_forward_decode`` a clean
        sentinel.
        """
        plans: dict[str, FlashMLASchedMeta | None] = dict.fromkeys(
            (_LAYER_TYPE_SWAONLY, _LAYER_TYPE_C4A, _LAYER_TYPE_C128A)
        )
        if num_decode_tokens != 0:
            for layer_type in self._layer_types:
                # get_mla_metadata() is the official FlashMLA entry point that
                # returns a fresh empty FlashMLASchedMeta; using it keeps this
                # call site aligned with the rest of the vLLM FlashMLA backends
                # that already go through the same stub.
                plans[layer_type] = get_mla_metadata()[0]
        return plans

    def _build_deepseek_v4_metadata(
        self,
        num_decodes: int,
        num_prefills: int,
        seq_lens: torch.Tensor,
        query_start_loc: torch.Tensor,
    ) -> dict[str, torch.Tensor | None]:
        """Compute the prefill fields of ``DeepseekSparseSWAMetadata``.

        Returns keyword arguments for the metadata constructor: an empty dict
        when there are no prefill requests, otherwise ``prefill_seq_lens``
        (a view of ``seq_lens`` past the decode requests) and
        ``prefill_gather_lens`` (filled by a single Triton program).

        Note: C128A topk indices are computed by the FlashMLASparse builder
        (which owns the C128A block_table), not here.
        """
        if num_prefills <= 0:
            return {}

        # --- Prefill query metadata (single Triton kernel + CPU slicing) ---
        gather_lens = torch.empty(
            num_prefills, dtype=torch.int32, device=seq_lens.device
        )
        _compute_prefill_metadata_kernel[(1,)](
            gather_lens,
            seq_lens,
            query_start_loc,
            num_prefills,
            num_decodes,
            self.window_size,
            BLOCK_SIZE=triton.next_power_of_2(num_prefills),
        )
        fields: dict[str, torch.Tensor | None] = {}
        fields["prefill_seq_lens"] = seq_lens[num_decodes:]
        fields["prefill_gather_lens"] = gather_lens
        return fields
@triton.jit
def _compute_prefill_metadata_kernel(
    # Outputs
    prefill_gather_lens_ptr,
    # Inputs
    seq_lens_ptr,
    query_start_loc_ptr,
    num_prefills,
    num_decodes,
    window_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute prefill gather_lens in a single pass.

    All prefill requests are processed by one block of ``BLOCK_SIZE`` lanes
    (lane ``k`` handles request ``num_decodes + k``); lanes at or beyond
    ``num_prefills`` are masked out. For each request the gather length is
    its query length plus at most ``window_size - 1`` prefix tokens.
    """
    offset = tl.arange(0, BLOCK_SIZE)
    mask = offset < num_prefills

    # Prefill requests follow the decode requests, hence the num_decodes shift.
    seq_len = tl.load(seq_lens_ptr + num_decodes + offset, mask=mask)
    qsl_start = tl.load(query_start_loc_ptr + num_decodes + offset, mask=mask)
    qsl_end = tl.load(query_start_loc_ptr + num_decodes + offset + 1, mask=mask)
    query_len = qsl_end - qsl_start
    # Tokens of the sequence that precede this step's query tokens.
    prefix_len = seq_len - query_len

    # Query tokens plus the part of the prefix covered by the sliding window.
    gather_len = query_len + tl.minimum(prefix_len, window_size - 1)
    tl.store(prefill_gather_lens_ptr + offset, gather_len, mask=mask)
@triton.jit
def _compute_swa_indices_and_lens_kernel(
    swa_indices_ptr,
    swa_indices_stride,
    swa_lens_ptr,
    window_size,
    query_start_loc_ptr,
    seq_lens_ptr,
    token_to_req_indices_ptr,
    is_valid_token_ptr,
    block_table_ptr,
    block_table_stride,
    block_size,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    """Emit the sliding-window KV slot ids and window length of one token.

    One program per token. Invalid tokens get a length of 0 and their index
    row is not written. For valid tokens the first ``window_size`` entries of
    the row are written: real slot ids for the window, -1 for the remainder.
    """
    tok = tl.program_id(0)
    token_ok = tl.load(is_valid_token_ptr + tok)
    if not token_ok:
        tl.store(swa_lens_ptr + tok, 0)
        return

    req = tl.load(token_to_req_indices_ptr + tok)
    q_begin = tl.load(query_start_loc_ptr + req)
    q_finish = tl.load(query_start_loc_ptr + req + 1)
    n_query = q_finish - q_begin
    cached = tl.load(seq_lens_ptr + req) - n_query

    # Absolute position of this token inside its sequence.
    abs_pos = cached + tok - q_begin
    window_first = tl.maximum(abs_pos - window_size + 1, 0)
    window_stop = abs_pos + 1
    window_len = window_stop - window_first
    tl.store(swa_lens_ptr + tok, window_len)

    out_row_ptr = swa_indices_ptr + tok * swa_indices_stride
    table_row_ptr = block_table_ptr + req * block_table_stride
    for chunk_start in range(0, window_size, TRITON_BLOCK_SIZE):
        lane = chunk_start + tl.arange(0, TRITON_BLOCK_SIZE)
        kv_pos = window_first + lane
        # Translate the logical position into a physical cache slot.
        physical_block = tl.load(
            table_row_ptr + kv_pos // block_size,
            mask=kv_pos < window_stop,
        )
        slot = physical_block * block_size + kv_pos % block_size
        slot = tl.where(lane < window_len, slot, -1)
        tl.store(out_row_ptr + lane, slot, mask=lane < window_size)
......@@ -356,7 +356,7 @@ def make_local_attention_virtual_batches(
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
seq_lens_cpu_upper_bound=seq_lens_cpu,
seq_lens_cpu_upper_bound=common_attn_metadata.seq_lens_cpu_upper_bound,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
), make_block_table
......
......@@ -265,6 +265,7 @@ def _pack_seq_kernel(
D: tl.constexpr,
Lmax: tl.constexpr,
PAD_VALUE: tl.constexpr,
PAD_IS_UINT8: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr, # features per program
):
......@@ -294,9 +295,15 @@ def _pack_seq_kernel(
# out_ptr: row-major [B, Lmax, D]
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
# Initialize with PAD. PAD_IS_UINT8 selects the pad tensor's dtype so
# integer-typed outputs (e.g. MXFP4 packed nibbles, ue8m0 scale bytes)
# get an exact-byte pad rather than going through an fp32→uint8 cast
# that's implementation-defined outside of value 0.
d_mask = off_d[None, :] < D
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
if PAD_IS_UINT8:
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.uint8)
else:
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
# Load & write only where within seq_len
......@@ -307,23 +314,36 @@ def _pack_seq_kernel(
def pack_seq_triton(
x: torch.Tensor,
lengths: torch.Tensor,
pad_value: float = -float("inf"),
pad_value: float | int = -float("inf"),
block_t: int = 64,
block_d: int = 64,
) -> torch.Tensor:
"""
Pack sequences of different lengths into a batched tensor.
"""Pack sequences of different lengths into a batched tensor.
Supports float dtypes (any, via fp32 pad) and ``torch.uint8`` (exact-byte
pad — e.g. MXFP4 packed nibbles or ue8m0 scale bytes). For uint8 inputs
``pad_value`` must be an integer in ``[0, 255]``.
Args:
x: [N, ...] - input tensor where N is total number of tokens
lengths: [B] - sequence lengths for each batch
pad_value: value to use for padding
block_t: block size for time dimension
block_d: block size for feature dimension
x: [N, ...] — input tensor where N is total number of tokens.
lengths: [B] — sequence lengths for each batch.
pad_value: value to use for padding. Defaults to ``-inf`` which is
only sensible for float dtypes; pass ``0`` (or any byte) for
uint8 inputs.
block_t: block size for time dimension.
block_d: block size for feature dimension.
Returns:
packed: [B, Lmax, ...] - packed tensor
packed: [B, Lmax, ...] packed tensor.
"""
is_uint8 = x.dtype == torch.uint8
if is_uint8:
assert isinstance(pad_value, int) and 0 <= pad_value <= 255, (
f"uint8 pack requires an integer pad in [0, 255], got {pad_value!r}"
)
pad_constexpr: int | float = int(pad_value)
else:
pad_constexpr = float(pad_value)
# Handle multi-dimensional input by reshaping to (N, -1)
original_shape = x.shape
......@@ -338,8 +358,6 @@ def pack_seq_triton(
B = lengths.numel()
Lmax = int(lengths.max().item())
# Starts are computed inside the kernel from lengths
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
......@@ -350,17 +368,16 @@ def pack_seq_triton(
N,
D,
Lmax,
PAD_VALUE=float(pad_value),
PAD_VALUE=pad_constexpr,
PAD_IS_UINT8=is_uint8,
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2,
)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 2:
output_shape = (B, Lmax) + original_shape[1:]
out = out.reshape(output_shape)
out = out.reshape((B, Lmax) + original_shape[1:])
return out
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .cache_utils import (
combine_topk_swa_indices,
compute_global_topk_indices_and_lens,
dequantize_and_gather_k_cache,
quantize_and_insert_k_cache,
)
from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant
from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant
from .fused_qk_rmsnorm import fused_q_kv_rmsnorm
__all__ = [
"MXFP4_BLOCK_SIZE",
"combine_topk_swa_indices",
"compute_global_topk_indices_and_lens",
"dequantize_and_gather_k_cache",
"fused_indexer_q_rope_quant",
"fused_inv_rope_fp8_quant",
"fused_q_kv_rmsnorm",
"quantize_and_insert_k_cache",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Triton kernels for DeepseekV4 paged K-cache management and sparse-attention index
preparation.
- quantize_and_insert_k_cache: quantize bf16 K to UE8M0 FP8 and insert into
the paged cache.
- dequantize_and_gather_k_cache: gather and dequantize FP8 K from the paged
cache for sparse/SWA prefill.
- compute_global_topk_indices_and_lens: map local topk indices to global KV
cache slots and count valid entries.
- combine_topk_swa_indices: concatenate topk compressed indices with SWA
window indices for sparse prefill.
"""
import torch
from vllm.triton_utils import tl, triton
# Per-token kernel: the first fp8_dim elements of a K row are quantized block by
# block to FP8 (e4m3) with a power-of-two scale, the remaining bf16_dim elements
# are copied unquantized, and everything is written into the token's slot of the
# paged cache.
@triton.jit
def quantize_and_insert_k_kernel(
# Input tensors
k_ptr, # [num_tokens, 512] bf16
slot_mapping_ptr, # [num_tokens] int64
# Output tensor
k_cache_ptr, # [num_blocks, block_bytes] as uint8 (flattened view)
# Dimensions
num_tokens,
input_dim: tl.constexpr, # 512
fp8_dim: tl.constexpr, # 448
bf16_dim: tl.constexpr, # 64
scale_dim: tl.constexpr, # 8
quant_block: tl.constexpr, # 64 (quantization block size)
cache_block_size: tl.constexpr, # 64 (paged cache block size)
token_data_size: tl.constexpr, # 576 bytes per token data
block_stride: tl.constexpr, # total bytes per block (padded)
fp8_max: tl.constexpr,
n_quant_blocks: tl.constexpr, # 8 (7 real + 1 padding)
):
"""
Quantize K tensor and insert into paged K cache.
K Cache block layout (block_size=64 tokens):
- [0, 64*576): Token data, each token has 448 fp8 + 128 bf16
- [64*576, 64*576 + 64*8): Scales, each token has 8 uint8 scales
- [64*576 + 64*8, block_stride): Padding
One program per token.
"""
pid = tl.program_id(0)
# Programs beyond the token count have nothing to do.
if pid >= num_tokens:
return
# Get slot mapping
slot_idx = tl.load(slot_mapping_ptr + pid)
# A slot of -1 means no cache destination for this token; nothing is written.
if slot_idx == -1:
return
block_idx = slot_idx // cache_block_size
pos_in_block = slot_idx % cache_block_size
# Input pointer for this token
# Rows are addressed as pid * input_dim, i.e. the input is read as densely
# packed rows of input_dim elements.
input_row_ptr = k_ptr + pid * input_dim
# int64: block_idx * block_stride can exceed 2^31 with many KV-cache blocks
# (e.g. >= 57K at block_stride ~37K). Matches gather path below.
cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride
# Token data pointer: token data is stored contiguously at start of block
# Each token's data is at offset pos_in_block * token_data_size
token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
# Scale pointer: scales are stored after ALL token data in the block
# Scale for this token is at offset (64 * 576) + pos_in_block * 8
token_scale_ptr = (
cache_block_ptr + cache_block_size * token_data_size + pos_in_block * scale_dim
)
# Token data layout: [0:448] fp8, [448:576] bf16
token_fp8_ptr = token_data_ptr
token_bf16_ptr = token_data_ptr + fp8_dim
# ========== Quantize and store FP8 portion (first 448 elements) ==========
# Using UE8M0 quantization strategy (scale is power of 2, stored as uint8 exponent)
# The loop runs n_quant_blocks times, but iterations whose start offset is not
# below fp8_dim are skipped by the guard; with the documented values (8 blocks
# of 64 over 448 elements) the eighth iteration does nothing.
for qblock_idx in tl.static_range(n_quant_blocks):
qblock_start = qblock_idx * quant_block
if qblock_start < fp8_dim:
offsets = qblock_start + tl.arange(0, quant_block)
mask = offsets < fp8_dim
# Load bf16 input
x = tl.load(input_row_ptr + offsets, mask=mask, other=0.0)
# Compute absmax scale (same as CUDA kernel)
abs_x = tl.abs(x)
block_max = tl.max(abs_x, axis=0)
block_max = tl.maximum(block_max, 1e-4) # Match CUDA: fmaxf(amax, 1e-4)
# UE8M0: Round scale UP to next power of 2
# scale = 2^ceil(log2(block_max / fp8_max))
raw_scale = block_max / fp8_max
log_scale = tl.log2(raw_scale)
exponent = tl.ceil(log_scale) # Round UP to next integer exponent
scale = tl.exp2(exponent) # scale = 2^exponent (power of 2)
# Quantize to fp8: fp8_value = bf16_value / scale
x_scaled = x / scale
x_clamped = tl.clamp(x_scaled, -fp8_max, fp8_max)
# Convert to fp8, then bitcast to uint8 for storage
x_fp8 = x_clamped.to(tl.float8e4nv)
x_uint8 = x_fp8.to(tl.uint8, bitcast=True)
# Store as uint8 (1 byte each)
tl.store(token_fp8_ptr + offsets, x_uint8, mask=mask)
# UE8M0 scale encoding: stored_value = exponent + 127 (bias)
# During dequant: scale = 2^(stored_value - 127)
encoded_scale = exponent + 127.0
# Clamp the biased exponent into the uint8 range before the cast.
encoded_scale = tl.maximum(tl.minimum(encoded_scale, 255.0), 0.0)
tl.store(token_scale_ptr + qblock_idx, encoded_scale.to(tl.uint8))
# Padding scale at index 7
# NOTE(review): the index is hard-coded; it equals scale_dim - 1 only for the
# documented scale_dim of 8 — confirm before reusing with another layout.
tl.store(token_scale_ptr + 7, tl.zeros((), dtype=tl.uint8))
# ========== Store BF16 portion (last 64 elements, no quantization) ==========
bf16_input_offset = fp8_dim
# Process bf16 in chunks of 16
# Reinterpret the byte pointer as a bf16 pointer so that chunk_offsets below
# count bf16 elements rather than bytes.
bf16_out_ptr = token_bf16_ptr.to(tl.pointer_type(tl.bfloat16))
# Unmasked: only whole chunks of 16 are copied (bf16_dim // 16 iterations).
for i in tl.static_range(bf16_dim // 16):
chunk_offsets = i * 16 + tl.arange(0, 16)
bf16_vals = tl.load(input_row_ptr + bf16_input_offset + chunk_offsets)
tl.store(bf16_out_ptr + chunk_offsets, bf16_vals)
def quantize_and_insert_k_cache(
    k: torch.Tensor,  # [num_tokens, 512] bf16
    k_cache: torch.Tensor,  # [num_blocks, block_bytes] uint8
    slot_mapping: torch.Tensor,  # [num_tokens] int64
    block_size: int = 64,
    is_ue8m0: bool = True,
) -> None:
    """Quantize K and insert it into the paged K cache (in place).

    K Cache block layout (block_size=64 tokens):
    - First 64 * 576 = 36864 bytes: Token data
      - Each token: 448 bytes (fp8) + 128 bytes (bf16)
    - Next 64 * 8 = 512 bytes: Scales
      - Each token: 8 bytes (uint8 scales, 7 real + 1 padding)
    - Padded to multiple of 576

    Args:
        k: ``[num_tokens, 512]`` bf16 keys. A non-contiguous tensor is
            copied to a contiguous one before the launch.
        k_cache: paged cache viewed as bytes; ``k_cache.stride(0)`` is used
            as the per-block byte pitch.
        slot_mapping: destination slot per token; its length defines how
            many tokens are processed.
        block_size: number of tokens per cache block.
        is_ue8m0: must be ``True``; only UE8M0 scales are implemented.

    Raises:
        AssertionError: if ``k`` has the wrong shape or dtype, or if
            ``is_ue8m0`` is false.
    """
    assert k.dim() == 2 and k.shape[1] == 512, (
        f"K must be [num_tokens, 512], got {k.shape}"
    )
    assert k.dtype == torch.bfloat16, f"K must be bf16, got {k.dtype}"
    assert is_ue8m0, "Only support ue8m0 quantization."

    # The kernel locates token i at k_ptr + i * 512 and is never told the real
    # strides, so a strided view (e.g. a column slice of a wider tensor) would
    # be read incorrectly. contiguous() is a no-op for already-dense input.
    k = k.contiguous()

    # NOTE: When using DP, slot_mapping.shape[0] can be less than k.shape[0] due to
    # padding. Always use slot_mapping.shape[0] as the token count.
    num_tokens = slot_mapping.shape[0]
    block_stride = k_cache.stride(0)  # bytes per block

    TOKEN_FP8_DIM = 448
    TOKEN_BF16_DIM = 64
    TOKEN_SCALE_DIM = 8
    QUANT_BLOCK_SIZE = 64
    FP8_MAX = 448.0
    # One byte per fp8 value plus two bytes per bf16 value.
    TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2

    grid = (num_tokens,)
    quantize_and_insert_k_kernel[grid](
        k,
        slot_mapping,
        k_cache,
        num_tokens,
        input_dim=512,
        fp8_dim=TOKEN_FP8_DIM,
        bf16_dim=TOKEN_BF16_DIM,
        scale_dim=TOKEN_SCALE_DIM,
        quant_block=QUANT_BLOCK_SIZE,
        cache_block_size=block_size,
        token_data_size=TOKEN_DATA_SIZE,
        block_stride=block_stride,
        fp8_max=FP8_MAX,
        n_quant_blocks=8,
    )
# Gather kernel: for each request, reads the last gather_len tokens of its
# sequence from the paged cache, dequantizes the FP8 part with the stored
# exponent-byte scales, copies the bf16 part, and writes rows into `out`.
# Grid axis 0 indexes the request; axis 1 is a pool of workers that stride
# over that request's tokens.
@triton.jit
def _dequantize_and_gather_k_kernel(
out_ptr,
out_stride0,
out_stride1,
k_cache_ptr,
seq_lens_ptr,
block_table_ptr,
offset,
gather_lens_ptr,
# Constants
max_blocks_per_seq: tl.constexpr,
fp8_dim: tl.constexpr, # 448
bf16_dim: tl.constexpr, # 64
scale_dim: tl.constexpr, # 8
quant_block: tl.constexpr, # 64 (quantization block size)
cache_block_size: tl.constexpr, # 64 or 128 (paged cache block size)
token_data_size: tl.constexpr, # 576 bytes per token data
block_stride: tl.constexpr, # total bytes per block (padded) int32
output_dim: tl.constexpr, # 512
fp8_max: tl.constexpr,
n_quant_blocks: tl.constexpr, # 7 real blocks
):
# NOTE: output_dim and fp8_max are accepted but not referenced in this body.
batch_idx = tl.program_id(0)
worker_id = tl.program_id(1)
num_workers = tl.num_programs(1)
seq_len = tl.load(seq_lens_ptr + batch_idx)
if gather_lens_ptr is not None: # noqa: SIM108
gather_len = tl.load(gather_lens_ptr + batch_idx)
else:
# Gather all tokens
gather_len = seq_len
# The gathered tokens are the trailing gather_len positions of the sequence.
start_pos = seq_len - gather_len
# Each worker handles tokens worker_id, worker_id + num_workers, ...
for i in range(worker_id, gather_len, num_workers):
# Calculate the actual token index in the sequence
pos = start_pos + i
# Calculate which block and position within block
block_in_seq = pos // cache_block_size
pos_in_block = pos % cache_block_size
# Get physical block index from block table
# max_blocks_per_seq is used here as the element pitch between block-table rows.
block_table_row_ptr = block_table_ptr + batch_idx * max_blocks_per_seq
physical_block_idx = tl.load(block_table_row_ptr + block_in_seq) # int32
# int64: physical_block_idx * block_stride can exceed 2^31 with many
# KV-cache blocks (e.g. >= 57K at block_stride ~37K).
cache_block_ptr = k_cache_ptr + physical_block_idx.to(tl.int64) * block_stride
# Token data pointer
token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
# Scale pointer: after all token data
token_scale_ptr = (
cache_block_ptr
+ cache_block_size * token_data_size
+ pos_in_block * scale_dim
)
# Token data layout: [0:448] fp8, [448:576] bf16
token_fp8_ptr = token_data_ptr
token_bf16_ptr = token_data_ptr + fp8_dim
# Output pointer for this token (flattened)
# The i-th gathered token lands at row (offset + i) of this request's output.
output_row_ptr = out_ptr + batch_idx * out_stride0 + (offset + i) * out_stride1
# ========== Dequantize FP8 portion using UE8M0 ==========
for qblock_idx in tl.static_range(n_quant_blocks):
qblock_start = qblock_idx * quant_block
if qblock_start < fp8_dim:
offsets = qblock_start + tl.arange(0, quant_block)
mask = offsets < fp8_dim
# Load quantized fp8 values (stored as uint8)
x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0)
# Bitcast uint8 back to fp8
x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
# Convert fp8 to float32 for computation
x_float = x_fp8.to(tl.float32)
# Load and decode UE8M0 scale
# UE8M0: scale = 2^(stored_value - 127)
encoded_scale = tl.load(token_scale_ptr + qblock_idx)
exponent = encoded_scale.to(tl.float32) - 127.0
scale = tl.exp2(exponent)
# Dequantize: bf16_value = fp8_value * scale
x_dequant = x_float * scale
# Store as bf16
tl.store(output_row_ptr + offsets, x_dequant.to(tl.bfloat16), mask=mask)
# ========== Copy BF16 portion directly ==========
bf16_output_offset = fp8_dim # After 448 elements in output
# Read bf16 from cache
# Reinterpret the byte pointer as bf16 so chunk_offsets count bf16 elements.
bf16_cache_ptr = token_bf16_ptr.to(tl.pointer_type(tl.bfloat16))
# Process in chunks of 16
for j in tl.static_range(bf16_dim // 16):
chunk_offsets = j * 16 + tl.arange(0, 16)
bf16_vals = tl.load(bf16_cache_ptr + chunk_offsets)
tl.store(output_row_ptr + bf16_output_offset + chunk_offsets, bf16_vals)
def dequantize_and_gather_k_cache(
    # [num_reqs, max_num_tokens, head_size]
    out: torch.Tensor,
    # [num_blocks, block_size, head_bytes]
    k_cache: torch.Tensor,
    # [num_reqs]
    seq_lens: torch.Tensor,
    # [num_reqs]
    gather_lens: torch.Tensor | None,
    # [num_reqs, max_blocks_per_seq]
    block_table: torch.Tensor,
    block_size: int,
    offset: int,
) -> None:
    """Gather K rows from the paged cache, dequantize them, and fill ``out``.

    For every request the trailing ``gather_lens[r]`` tokens of its sequence
    (or the whole sequence when ``gather_lens`` is ``None``) are read through
    ``block_table``, the FP8 part is dequantized to bf16, the bf16 part is
    copied, and the rows are written to ``out[r, offset + i]``.

    Args:
        out: destination buffer, written in place.
        k_cache: paged cache; ``k_cache.stride(0)`` is the per-block pitch.
        seq_lens: sequence length of each request.
        gather_lens: number of trailing tokens to gather per request, or
            ``None`` to gather every token.
        block_table: logical-to-physical block mapping per request.
        block_size: tokens per cache block.
        offset: first row of ``out[r]`` that receives gathered tokens.
    """
    TOKEN_FP8_DIM = 448
    TOKEN_BF16_DIM = 64
    TOKEN_SCALE_DIM = 8
    QUANT_BLOCK_SIZE = 64
    FP8_MAX = 448.0
    # One byte per fp8 value plus two bytes per bf16 value.
    TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2

    num_reqs = seq_lens.shape[0]
    NUM_WORKERS = 128
    _dequantize_and_gather_k_kernel[(num_reqs, NUM_WORKERS)](
        out,
        out.stride(0),
        out.stride(1),
        k_cache,
        seq_lens,
        block_table,
        offset,
        gather_lens,
        # The kernel uses this value purely as the pitch between block-table
        # rows, so pass the real row stride. shape[-1] is only correct for a
        # densely packed table and breaks for column-sliced views.
        max_blocks_per_seq=block_table.stride(0),
        fp8_dim=TOKEN_FP8_DIM,
        bf16_dim=TOKEN_BF16_DIM,
        scale_dim=TOKEN_SCALE_DIM,
        quant_block=QUANT_BLOCK_SIZE,
        cache_block_size=block_size,
        token_data_size=TOKEN_DATA_SIZE,
        block_stride=k_cache.stride(0),
        output_dim=512,
        fp8_max=FP8_MAX,
        n_quant_blocks=7,
    )
def compute_global_topk_indices_and_lens(
    topk_indices: torch.Tensor,
    token_to_req_indices: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    is_valid_token: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Translate per-request top-k indices into global KV cache slot ids.

    A single kernel launch (one program per token) performs the block-table
    lookup, counts the non-negative entries of each row, and forces the count
    of tokens flagged invalid to 0.

    Returns:
        ``(global_indices, lens)`` where ``global_indices`` has the layout of
        ``topk_indices`` and ``lens`` is an int32 vector with one entry per
        token.
    """
    token_count = topk_indices.shape[0]
    device = topk_indices.device

    slots_out = torch.empty_like(topk_indices)
    lens_out = torch.empty(token_count, dtype=torch.int32, device=device)

    launch = _compute_global_topk_indices_and_lens_kernel[(token_count,)]
    launch(
        slots_out,
        slots_out.stride(0),
        lens_out,
        topk_indices,
        topk_indices.stride(0),
        topk_indices.shape[-1],
        token_to_req_indices,
        block_table,
        block_table.stride(0),
        block_size,
        is_valid_token,
        TRITON_BLOCK_SIZE=1024,
    )
    return slots_out, lens_out
@triton.jit
def _compute_global_topk_indices_and_lens_kernel(
    global_topk_indices_ptr,
    global_topk_indices_stride,
    topk_lens_ptr,
    topk_indices_ptr,
    topk_indices_stride,
    topk,
    token_to_req_indices_ptr,
    block_table_ptr,
    block_table_stride,
    block_size,
    is_valid_token_ptr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    """Map one token's local top-k indices to global slots and count them.

    Negative local indices are treated as empty: they are written back as -1
    and do not contribute to the count. The stored count is 0 when the token
    is flagged invalid.
    """
    tok = tl.program_id(0)
    token_ok = tl.load(is_valid_token_ptr + tok)
    req = tl.load(token_to_req_indices_ptr + tok)

    src_row_ptr = topk_indices_ptr + tok * topk_indices_stride
    dst_row_ptr = global_topk_indices_ptr + tok * global_topk_indices_stride
    table_row_ptr = block_table_ptr + req * block_table_stride

    n_valid = tl.zeros((), dtype=tl.int32)
    for chunk_start in range(0, topk, TRITON_BLOCK_SIZE):
        lane = chunk_start + tl.arange(0, TRITON_BLOCK_SIZE)
        in_row = lane < topk
        # Lanes past the row end read -1 and are therefore treated as empty.
        logical = tl.load(src_row_ptr + lane, mask=in_row, other=-1)
        has_entry = logical >= 0

        physical_block = tl.load(
            table_row_ptr + logical // block_size,
            mask=in_row & has_entry,
        )
        slot = physical_block * block_size + logical % block_size
        slot = tl.where(has_entry, slot, -1)
        tl.store(dst_row_ptr + lane, slot, mask=in_row)

        n_valid += tl.sum(has_entry.to(tl.int32), axis=0)

    # Zero out length for padding tokens.
    tl.store(topk_lens_ptr + tok, tl.where(token_ok, n_valid, 0))
# FlashMLA sparse prefill asserts `params.topk % B_TOPK == 0` (see
# flashmla/csrc/sm100/prefill/sparse/fwd/head{64,128}/phase1.cuh). B_TOPK is
# 64 for the h_q=64 kernel and 128 for h_q=128; pad to 128 to satisfy both.
# The extra slots stay as -1 sentinels and `combined_lens` caps the valid
# range via `topk_length`, so padding is a no-op at kernel level.
_SPARSE_PREFILL_TOPK_ALIGNMENT = 128
def combine_topk_swa_indices(
    topk_indices: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    gather_lens: torch.Tensor,
    window_size: int,
    compress_ratio: int,
    topk: int,
    M: int,
    N: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Concatenate each token's top-k indices with its sliding-window indices.

    The output row width is ``topk + window_size`` rounded up to the sparse
    prefill alignment; unused entries keep the -1 fill value.

    Returns:
        ``(combined_indices, combined_lens)``: an int32
        ``[num_tokens, padded_width]`` index matrix and an int32 vector of
        per-token lengths written by the kernel.
    """
    device = topk_indices.device
    token_count = topk_indices.shape[0]
    request_count = seq_lens.shape[0]

    # Round the row width up to a multiple of the required alignment.
    align = _SPARSE_PREFILL_TOPK_ALIGNMENT
    padded_width = triton.cdiv(topk + window_size, align) * align

    indices_out = torch.full(
        (token_count, padded_width),
        fill_value=-1,
        dtype=torch.int32,
        device=device,
    )
    lens_out = torch.empty(token_count, dtype=torch.int32, device=device)

    # Axis 0: one entry per request; axis 1: workers striding over its tokens.
    NUM_WORKERS = 128
    launch = _combine_topk_swa_indices_kernel[(request_count, NUM_WORKERS)]
    launch(
        indices_out,
        indices_out.stride(0),
        lens_out,
        topk_indices,
        topk_indices.stride(0),
        query_start_loc,
        seq_lens,
        gather_lens,
        M,
        N,
        TOP_K=topk,
        COMPRESS_RATIO=compress_ratio,
        WINDOW_SIZE=window_size,
        PADDED_TOP_K=triton.next_power_of_2(topk_indices.shape[-1]),
    )
    return indices_out, lens_out
@triton.jit
def _combine_topk_swa_indices_kernel(
    combined_indices_ptr,
    combined_indices_stride,
    combined_lens_ptr,
    topk_indices_ptr,
    topk_indices_stride,
    query_start_loc_ptr,
    seq_lens_ptr,
    gather_lens_ptr,
    M,
    N,
    TOP_K: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    PADDED_TOP_K: tl.constexpr,
):
    """Fill one request's rows of the combined top-k + SWA index matrix.

    Program axis 0 selects the request; the programs on axis 1 stride over
    that request's query tokens. Each row receives the token's top-k entries
    shifted by ``M * request`` followed directly by its sliding-window
    entries, and the row length is stored in ``combined_lens``.
    """
    req = tl.program_id(0)
    worker = tl.program_id(1)
    worker_count = tl.num_programs(1)

    # query_start_loc holds global offsets; subtracting its first element
    # rebases them to this chunk.
    chunk_base = tl.load(query_start_loc_ptr)
    q_begin = tl.load(query_start_loc_ptr + req) - chunk_base
    q_finish = tl.load(query_start_loc_ptr + req + 1) - chunk_base
    n_query = q_finish - q_begin

    total_len = tl.load(seq_lens_ptr + req)
    gathered = tl.load(gather_lens_ptr + req)
    first_query_pos = total_len - n_query
    # Sequence position held by the first entry of the gathered SWA region.
    gather_origin = total_len - gathered

    for tok in range(q_begin + worker, q_finish, worker_count):
        # Both lengths depend only on the token's absolute position; a
        # TOP_K of 0 makes the top-k part empty.
        abs_pos = first_query_pos + (tok - q_begin)
        n_topk = tl.minimum((abs_pos + 1) // COMPRESS_RATIO, TOP_K)
        n_swa = tl.minimum(abs_pos + 1, WINDOW_SIZE)

        out_row_ptr = combined_indices_ptr + tok * combined_indices_stride

        # Top-k part: copy the valid prefix and shift it by M * request.
        topk_lane = tl.arange(0, PADDED_TOP_K)
        topk_keep = topk_lane < n_topk
        picked = tl.load(
            topk_indices_ptr + tok * topk_indices_stride + topk_lane,
            mask=topk_keep,
        )
        tl.store(out_row_ptr + topk_lane, picked + M * req, mask=topk_keep)

        # SWA part: positions [abs_pos - n_swa + 1, abs_pos], expressed
        # relative to gather_origin and shifted by M * request + N.
        swa_lane = tl.arange(0, WINDOW_SIZE)
        tl.store(
            out_row_ptr + n_topk + swa_lane,
            M * req + N + swa_lane + abs_pos - n_swa + 1 - gather_origin,
            mask=swa_lane < n_swa,
        )

        tl.store(combined_lens_ptr + tok, n_topk + n_swa)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fused compressor + FP8/MXFP4 UE8M0 quantization + KV cache insert kernels.
Three specialized kernels:
- _fused_kv_compress_norm_rope_insert_sparse_attn:
head=512, nope=448 FP8 + rope=64 bf16
- _fused_kv_compress_norm_rope_insert_indexer_attn:
head=128, all FP8, 1 block/token
- _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn:
head=128, MXFP4 (block=32), 4 ue8m0 bytes
RoPE is register-based via tl.reshape -> tl.split -> tl.interleave (or the
even/odd halves are consumed directly for MXFP4, no interleave needed).
FP8 UE8M0 quant uses tl.reshape to tile [N_QUANT_BLOCKS, QUANT_BLOCK] for
per-block absmax entirely in registers. MXFP4 does the same tiling on the
even/odd halves, producing (N_QUANT_BLOCKS, MXFP4_BLOCK/2) packed nibbles
and N_QUANT_BLOCKS ue8m0 bytes.
"""
from vllm.triton_utils import tl, triton
from .fused_indexer_q import _e2m1_nibble
# =============================================================================
# DeepseekV4 Attention path (head=512, nope=448 FP8 + rope=64 bf16)
# =============================================================================
# Per-token fused kernel: gathers (1 + OVERLAP) * COMPRESS_RATIO state-cache
# rows, combines them with softmax weights, applies RMSNorm, then writes the
# leading (HEAD_SIZE - ROPE_HEAD_DIM) lanes as FP8 with exponent-byte scales and
# the trailing ROPE_HEAD_DIM lanes as rotated bf16 into the KV cache.
@triton.jit
def _fused_kv_compress_norm_rope_insert_sparse_attn(
# ── state cache (compressor internal state) ──
state_cache_ptr,
state_cache_stride0,
state_cache_stride1,
# ── metadata ──
token_to_req_indices_ptr,
positions_ptr,
slot_mapping_ptr,
block_table_ptr,
block_table_stride,
block_size,
# ── RMSNorm ──
rms_norm_weight_ptr,
rms_norm_eps,
# ── RoPE ──
cos_sin_cache_ptr,
cos_sin_stride,
# ── KV cache output ──
k_cache_ptr,
kv_slot_mapping_ptr,
kv_cache_block_size,
# ── constexprs ──
HEAD_SIZE: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
STATE_WIDTH: tl.constexpr,
COMPRESS_RATIO: tl.constexpr,
OVERLAP: tl.constexpr,
ROPE_HEAD_DIM: tl.constexpr,
FP8_MAX: tl.constexpr, # 448.0
QUANT_BLOCK: tl.constexpr, # 64 for DeepseekV4
TOKEN_STRIDE: tl.constexpr, # 576 for DeepseekV4
SCALE_DIM: tl.constexpr, # 8 for DeepseekV4 (7 real + 1 pad)
KV_BLOCK_STRIDE: tl.constexpr,
):
"""Fused compress → RMSNorm → FP8 quant (nope) → RoPE → bf16 store (rope).
One program per token; early-exits for non-boundary positions.
Cache block layout (``block_size`` tokens):
[0, bs*576): token data (448 fp8 + 128 bf16 each)
[bs*576, +bs*8): uint8 UE8M0 scales (7 real + 1 pad each)
"""
token_idx = tl.program_id(0)
slot_id = tl.load(slot_mapping_ptr + token_idx)
# Tokens with a negative slot id are skipped entirely.
if slot_id < 0:
return
position = tl.load(positions_ptr + token_idx)
# Only the last position of each COMPRESS_RATIO-sized group emits an entry.
if (position + 1) % COMPRESS_RATIO != 0:
return
req_idx = tl.load(token_to_req_indices_ptr + token_idx)
# ── Gather state cache entries ────────────────────────────────────
# The gathered span is the (1 + OVERLAP) * COMPRESS_RATIO positions ending at
# `position`; entries that would fall before position 0 are masked out.
start = position - (1 + OVERLAP) * COMPRESS_RATIO + 1
tokens = tl.arange(0, (1 + OVERLAP) * COMPRESS_RATIO)
pos = start + tokens
mask_pos = pos >= 0
block_indices = pos // block_size
block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices,
mask=mask_pos,
other=0,
)
block_offsets = pos % block_size
# Rows with index >= COMPRESS_RATIO read at a column offset of HEAD_SIZE; the
# earlier rows read at column offset 0.
head_offset = (tokens >= COMPRESS_RATIO).to(tl.int32) * HEAD_SIZE
block = tl.arange(0, TRITON_BLOCK_SIZE)
mask = block < HEAD_SIZE
block_numbers_i64 = block_numbers.to(tl.int64)
# Precomputed row base shared by score and kv loads
row_base = (
state_cache_ptr
+ block_numbers_i64 * state_cache_stride0
+ block_offsets * state_cache_stride1
+ head_offset
)
combined_mask = mask_pos[:, None] & mask[None, :]
# ── Softmax + weighted sum ───────────────────────────────────────
# Scores sit STATE_WIDTH elements after the kv values of the same row. Masked
# entries load -inf so they receive zero softmax weight.
score = tl.load(
row_base[:, None] + STATE_WIDTH + block[None, :],
mask=combined_mask,
other=float("-inf"),
)
score = tl.softmax(score, dim=0)
kv = tl.load(
row_base[:, None] + block[None, :],
mask=combined_mask,
other=0.0,
)
compressed_kv = tl.sum(kv * score, axis=0) # [TRITON_BLOCK_SIZE] fp32
# ── RMSNorm (fp32 throughout) ──────────────────────────────────────
rms_w = tl.load(rms_norm_weight_ptr + block, mask=mask, other=0.0)
variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_SIZE
rrms = tl.rsqrt(variance + rms_norm_eps)
normed = compressed_kv * rrms * rms_w
# ── KV cache pointers ────────────────────────────────────────────
kv_slot_idx = tl.load(kv_slot_mapping_ptr + token_idx)
# A negative KV slot means there is no destination; nothing is written.
if kv_slot_idx < 0:
return
kv_block_idx = kv_slot_idx // kv_cache_block_size
kv_pos_in_block = kv_slot_idx % kv_cache_block_size
cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE
fp8_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE
# Scales are stored after the token data of all kv_cache_block_size tokens.
scale_ptr = (
cache_block_ptr
+ kv_cache_block_size * TOKEN_STRIDE
+ kv_pos_in_block * SCALE_DIM
)
NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM # 448
HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 # 32
# FP8 UE8M0 quant: cast fp32 → bf16 → fp32 before quant to match reference.
# NOTE(review): the reshape below presumably needs TRITON_BLOCK_SIZE to be a
# multiple of QUANT_BLOCK — confirm at the call site.
N_QUANT_BLOCKS: tl.constexpr = TRITON_BLOCK_SIZE // QUANT_BLOCK
N_NOPE_BLOCKS: tl.constexpr = NOPE_HEAD_DIM // QUANT_BLOCK # 7
INV_FP8_MAX: tl.constexpr = 1.0 / FP8_MAX
quant_input = normed.to(tl.bfloat16).to(tl.float32)
quant_2d = tl.reshape(quant_input, (N_QUANT_BLOCKS, QUANT_BLOCK))
abs_2d = tl.abs(quant_2d)
block_absmax = tl.max(abs_2d, axis=1) # [N_QUANT_BLOCKS] fp32
block_absmax = tl.maximum(block_absmax, 1e-4)
raw_scales = block_absmax * INV_FP8_MAX
# Round each scale up to a power of two; values are multiplied by 2^-exponent.
exponents = tl.ceil(tl.log2(raw_scales))
inv_scales = tl.exp2(-exponents)
inv_scales_col = tl.reshape(inv_scales, (N_QUANT_BLOCKS, 1))
x_scaled = quant_2d * inv_scales_col
x_clamped = tl.clamp(x_scaled, -FP8_MAX, FP8_MAX)
x_fp8 = x_clamped.to(tl.float8e4nv)
x_uint8 = x_fp8.to(tl.uint8, bitcast=True)
x_uint8_flat = tl.reshape(x_uint8, (TRITON_BLOCK_SIZE,))
# All lanes are quantized, but only the first NOPE_HEAD_DIM bytes are stored.
nope_mask = block < NOPE_HEAD_DIM
tl.store(fp8_ptr + block, x_uint8_flat, mask=nope_mask)
scale_idx = tl.arange(0, N_QUANT_BLOCKS)
# Scale byte = exponent + 127, clamped into the uint8 range.
encoded = exponents + 127.0
encoded = tl.maximum(tl.minimum(encoded, 255.0), 0.0)
tl.store(
scale_ptr + scale_idx,
encoded.to(tl.uint8),
mask=scale_idx < N_NOPE_BLOCKS,
)
# The byte following the real scales is written as 0.
tl.store(scale_ptr + N_NOPE_BLOCKS, tl.zeros((), dtype=tl.uint8))
# Register-based GPT-J RoPE in fp32.
# Lanes are viewed as adjacent (even, odd) pairs. Pairs in the nope region
# load cos=1 and sin=0, so the rotation below leaves them unchanged.
NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2
NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2
pair_2d = tl.reshape(normed, (NUM_PAIRS, 2))
even, odd = tl.split(pair_2d) # each [NUM_PAIRS] fp32
pair_idx = tl.arange(0, NUM_PAIRS)
rope_pair_local = pair_idx - NOPE_PAIRS
is_rope_pair = rope_pair_local >= 0
cs_idx = tl.maximum(rope_pair_local, 0)
# The cos/sin row is selected by the position rounded down to a multiple of
# COMPRESS_RATIO.
compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO
cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride
# Within a row, cos values come first and sin values start HALF_ROPE later.
cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0)
sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0)
new_even = even * cos_v - odd * sin_v
new_odd = odd * cos_v + even * sin_v
result = tl.interleave(new_even, new_odd) # [TRITON_BLOCK_SIZE] fp32
# Store rotated rope portion as bf16 into the cache's bf16 area.
# The bf16 area starts NOPE_HEAD_DIM bytes after the token's fp8 data; the
# pointer is reinterpreted as bf16 so rope_local counts bf16 elements.
bf16_ptr = (fp8_ptr + NOPE_HEAD_DIM).to(tl.pointer_type(tl.bfloat16))
rope_local = block - NOPE_HEAD_DIM
is_rope = (block >= NOPE_HEAD_DIM) & mask
tl.store(bf16_ptr + rope_local, result.to(tl.bfloat16), mask=is_rope)
# =============================================================================
# Indexer path (head=128, all FP8, single quant block)
# =============================================================================
# Per-token fused kernel for the indexer cache: same gather / softmax-weighted
# sum / RMSNorm front end as the sparse-attention variant, then RoPE on the
# trailing ROPE_HEAD_DIM lanes, a single-block FP8 quantization of the whole
# head, and a store of the FP8 bytes plus one float32 scale.
@triton.jit
def _fused_kv_compress_norm_rope_insert_indexer_attn(
# ── state cache (compressor internal state) ──
state_cache_ptr,
state_cache_stride0,
state_cache_stride1,
# ── metadata ──
token_to_req_indices_ptr,
positions_ptr,
slot_mapping_ptr,
block_table_ptr,
block_table_stride,
block_size,
# ── RMSNorm ──
rms_norm_weight_ptr,
rms_norm_eps,
# ── RoPE ──
cos_sin_cache_ptr,
cos_sin_stride,
# ── KV cache output ──
k_cache_ptr,
kv_slot_mapping_ptr,
kv_cache_block_size,
# ── constexprs ──
HEAD_SIZE: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
STATE_WIDTH: tl.constexpr,
COMPRESS_RATIO: tl.constexpr,
OVERLAP: tl.constexpr,
ROPE_HEAD_DIM: tl.constexpr,
FP8_MAX: tl.constexpr, # 448.0
QUANT_BLOCK: tl.constexpr, # 128 for indexer
TOKEN_STRIDE: tl.constexpr, # 128 for indexer
SCALE_DIM: tl.constexpr, # 4 for indexer (1 float32)
KV_BLOCK_STRIDE: tl.constexpr,
):
"""Fused compress → RMSNorm → RoPE → FP8 quant → store.
One program per token; early-exits for non-boundary positions.
Cache block layout:
[0, bs*128): FP8 data (128 bytes/token)
[bs*128, +bs*4): float32 scales (4 bytes/token)
For head_dim=128 we have exactly one quant block, so we skip the
[N_QUANT_BLOCKS, QUANT_BLOCK] reshape entirely and use a flat
``tl.max`` reduction.
"""
token_idx = tl.program_id(0)
slot_id = tl.load(slot_mapping_ptr + token_idx)
# Tokens with a negative slot id are skipped entirely.
if slot_id < 0:
return
position = tl.load(positions_ptr + token_idx)
# Only the last position of each COMPRESS_RATIO-sized group emits an entry.
if (position + 1) % COMPRESS_RATIO != 0:
return
req_idx = tl.load(token_to_req_indices_ptr + token_idx)
# ── Gather state cache entries ────────────────────────────────────
# The gathered span is the (1 + OVERLAP) * COMPRESS_RATIO positions ending at
# `position`; entries that would fall before position 0 are masked out.
start = position - (1 + OVERLAP) * COMPRESS_RATIO + 1
tokens = tl.arange(0, (1 + OVERLAP) * COMPRESS_RATIO)
pos = start + tokens
mask_pos = pos >= 0
block_indices = pos // block_size
block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices,
mask=mask_pos,
other=0,
)
block_offsets = pos % block_size
# Rows with index >= COMPRESS_RATIO read at a column offset of HEAD_SIZE; the
# earlier rows read at column offset 0.
head_offset = (tokens >= COMPRESS_RATIO).to(tl.int32) * HEAD_SIZE
block = tl.arange(0, TRITON_BLOCK_SIZE)
mask = block < HEAD_SIZE
block_numbers_i64 = block_numbers.to(tl.int64)
row_base = (
state_cache_ptr
+ block_numbers_i64 * state_cache_stride0
+ block_offsets * state_cache_stride1
+ head_offset
)
combined_mask = mask_pos[:, None] & mask[None, :]
# Scores sit STATE_WIDTH elements after the kv values of the same row. Masked
# entries load -inf so they receive zero softmax weight.
score = tl.load(
row_base[:, None] + STATE_WIDTH + block[None, :],
mask=combined_mask,
other=float("-inf"),
)
score = tl.softmax(score, dim=0)
kv = tl.load(
row_base[:, None] + block[None, :],
mask=combined_mask,
other=0.0,
)
compressed_kv = tl.sum(kv * score, axis=0) # [TRITON_BLOCK_SIZE] fp32
# ── RMSNorm (fp32 throughout) ──────────────────────────────────────
rms_w = tl.load(rms_norm_weight_ptr + block, mask=mask, other=0.0)
variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_SIZE
rrms = tl.rsqrt(variance + rms_norm_eps)
normed = compressed_kv * rrms * rms_w
# ── KV cache pointers ────────────────────────────────────────────
kv_slot_idx = tl.load(kv_slot_mapping_ptr + token_idx)
# A negative KV slot means there is no destination; nothing is written.
if kv_slot_idx < 0:
return
kv_block_idx = kv_slot_idx // kv_cache_block_size
kv_pos_in_block = kv_slot_idx % kv_cache_block_size
cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE
fp8_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE
# Scales are stored after the token data of all kv_cache_block_size tokens.
scale_ptr = (
cache_block_ptr
+ kv_cache_block_size * TOKEN_STRIDE
+ kv_pos_in_block * SCALE_DIM
)
NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM
HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2
# ── Register-based GPT-J forward RoPE in fp32 ─────────────────────
# Lanes are viewed as adjacent (even, odd) pairs. Pairs in the nope region
# load cos=1 and sin=0, so the rotation below leaves them unchanged.
NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2
NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2
normed_2d = tl.reshape(normed, (NUM_PAIRS, 2))
even, odd = tl.split(normed_2d) # each [NUM_PAIRS] fp32
pair_idx = tl.arange(0, NUM_PAIRS)
rope_pair_local = pair_idx - NOPE_PAIRS
is_rope_pair = rope_pair_local >= 0
cs_idx = tl.maximum(rope_pair_local, 0)
# The cos/sin row is selected by the position rounded down to a multiple of
# COMPRESS_RATIO.
compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO
cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride
# Within a row, cos values come first and sin values start HALF_ROPE later.
cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0)
sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0)
new_even = even * cos_v - odd * sin_v
new_odd = odd * cos_v + even * sin_v
result = tl.interleave(new_even, new_odd) # fp32
# ── FP8 UE8M0 quant: single block, flat reduction ────────────────
tl.static_assert(
TRITON_BLOCK_SIZE == QUANT_BLOCK,
"Indexer expects one quant block (QUANT_BLOCK == TRITON_BLOCK_SIZE)",
)
INV_FP8_MAX: tl.constexpr = 1.0 / FP8_MAX
# Unlike the sparse-attention variant, quantization here runs on the rotated
# values, after a round trip through bf16.
result_bf16 = result.to(tl.bfloat16).to(tl.float32)
absmax = tl.max(tl.abs(result_bf16), axis=0) # scalar
absmax = tl.maximum(absmax, 1e-4)
raw_scale = absmax * INV_FP8_MAX
# Round the scale up to a power of two; values are multiplied by 2^-exponent.
exponent = tl.ceil(tl.log2(raw_scale))
inv_scale = tl.exp2(-exponent)
x_scaled = result_bf16 * inv_scale
x_clamped = tl.clamp(x_scaled, -FP8_MAX, FP8_MAX)
x_fp8 = x_clamped.to(tl.float8e4nv)
x_uint8 = x_fp8.to(tl.uint8, bitcast=True)
tl.store(fp8_ptr + block, x_uint8, mask=mask)
# Single float32 scale
# The scale is written as the float32 value 2^exponent (not as an encoded
# exponent byte) through a float32 reinterpretation of the scale pointer.
scale_val = tl.exp2(exponent)
tl.store(scale_ptr.to(tl.pointer_type(tl.float32)), scale_val)
# =============================================================================
# Indexer path (head=128, MXFP4: 2 nibbles/byte + ue8m0 per 32-elem block)
# =============================================================================
@triton.jit
def _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn(
    # ── state cache (compressor internal state) ──
    state_cache_ptr,
    state_cache_stride0,
    state_cache_stride1,
    # ── metadata ──
    token_to_req_indices_ptr,
    positions_ptr,
    slot_mapping_ptr,
    block_table_ptr,
    block_table_stride,
    block_size,
    # ── RMSNorm ──
    rms_norm_weight_ptr,
    rms_norm_eps,
    # ── RoPE ──
    cos_sin_cache_ptr,
    cos_sin_stride,
    # ── KV cache output ──
    k_cache_ptr,
    kv_slot_mapping_ptr,
    kv_cache_block_size,
    # ── constexprs ──
    HEAD_SIZE: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
    STATE_WIDTH: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
    OVERLAP: tl.constexpr,
    ROPE_HEAD_DIM: tl.constexpr,
    FP8_MAX: tl.constexpr,  # unused for MXFP4 (kept for signature parity)
    QUANT_BLOCK: tl.constexpr,  # 32 for MXFP4
    TOKEN_STRIDE: tl.constexpr,  # HEAD_SIZE // 2 = 64 packed bytes/token
    SCALE_DIM: tl.constexpr,  # HEAD_SIZE // QUANT_BLOCK = 4 ue8m0 bytes/token
    KV_BLOCK_STRIDE: tl.constexpr,
):
    """Fused compress → RMSNorm → RoPE → MXFP4 quant → store.

    One program per token (``tl.program_id(0)``). A program returns without
    writing anything when the token's ``slot_mapping`` entry is negative, when
    ``(position + 1) % COMPRESS_RATIO != 0`` (non-boundary position), or when
    its ``kv_slot_mapping`` entry is negative.

    Steps for a boundary token:
      1. Gather ``(1 + OVERLAP) * COMPRESS_RATIO`` state-cache rows ending at
         ``position`` (rows for negative positions are masked out) and reduce
         them with a softmax-weighted sum over the window.
      2. RMSNorm over ``HEAD_SIZE`` elements.
      3. GPT-J (interleaved) RoPE on the trailing ``ROPE_HEAD_DIM`` elements;
         the leading ``HEAD_SIZE - ROPE_HEAD_DIM`` elements pass through.
      4. Quantize each ``QUANT_BLOCK``-element block to MXFP4 and store the
         packed bytes and one ue8m0 scale byte per block.

    Cache block layout (``kv_cache_block_size`` tokens per cache block,
    abbreviated ``bs``):
        [0, bs*TOKEN_STRIDE): packed MXFP4 nibbles (2 values/byte)
        [bs*TOKEN_STRIDE, +bs*SCALE_DIM): ue8m0 scale bytes (one per
            QUANT_BLOCK-element block)

    MXFP4 format:
      - E2M1 4-bit values packed two per byte (low nibble first, then high).
      - Per-block scale = 2^ceil(log2(amax / 6.0)), stored ue8m0
        (byte = exponent + 127).
      - Max representable magnitude = 6.0.
    """
    token_idx = tl.program_id(0)
    # A negative slot marks a token this kernel must ignore.
    slot_id = tl.load(slot_mapping_ptr + token_idx)
    if slot_id < 0:
        return
    position = tl.load(positions_ptr + token_idx)
    # Only the last position of each COMPRESS_RATIO-sized window emits output.
    if (position + 1) % COMPRESS_RATIO != 0:
        return
    req_idx = tl.load(token_to_req_indices_ptr + token_idx)
    # ── Gather state cache entries ────────────────────────────────────
    # Window of (1 + OVERLAP) * COMPRESS_RATIO consecutive positions whose last
    # element is `position`; entries before position 0 are masked out.
    start = position - (1 + OVERLAP) * COMPRESS_RATIO + 1
    tokens = tl.arange(0, (1 + OVERLAP) * COMPRESS_RATIO)
    pos = start + tokens
    mask_pos = pos >= 0
    block_indices = pos // block_size
    block_numbers = tl.load(
        block_table_ptr + req_idx * block_table_stride + block_indices,
        mask=mask_pos,
        other=0,
    )
    block_offsets = pos % block_size
    # Window rows with index >= COMPRESS_RATIO read HEAD_SIZE elements further
    # into their state row than the earlier rows do.
    head_offset = (tokens >= COMPRESS_RATIO).to(tl.int32) * HEAD_SIZE
    block = tl.arange(0, TRITON_BLOCK_SIZE)
    mask = block < HEAD_SIZE
    # Widen before multiplying by the block stride to keep the offset in int64.
    block_numbers_i64 = block_numbers.to(tl.int64)
    row_base = (
        state_cache_ptr
        + block_numbers_i64 * state_cache_stride0
        + block_offsets * state_cache_stride1
        + head_offset
    )
    combined_mask = mask_pos[:, None] & mask[None, :]
    # Scores sit STATE_WIDTH elements after the values within a state row.
    # Masked entries load -inf so they drop out of the softmax.
    score = tl.load(
        row_base[:, None] + STATE_WIDTH + block[None, :],
        mask=combined_mask,
        other=float("-inf"),
    )
    # Softmax across the window axis (dim 0), independently per element.
    score = tl.softmax(score, dim=0)
    kv = tl.load(
        row_base[:, None] + block[None, :],
        mask=combined_mask,
        other=0.0,
    )
    compressed_kv = tl.sum(kv * score, axis=0)  # [TRITON_BLOCK_SIZE] fp32
    # ── RMSNorm (fp32 throughout) ──────────────────────────────────────
    rms_w = tl.load(rms_norm_weight_ptr + block, mask=mask, other=0.0)
    variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_SIZE
    rrms = tl.rsqrt(variance + rms_norm_eps)
    normed = compressed_kv * rrms * rms_w
    # ── KV cache pointers (segregated: values first, then scales) ────
    kv_slot_idx = tl.load(kv_slot_mapping_ptr + token_idx)
    # A negative KV slot means there is nowhere to store the result.
    if kv_slot_idx < 0:
        return
    kv_block_idx = kv_slot_idx // kv_cache_block_size
    kv_pos_in_block = kv_slot_idx % kv_cache_block_size
    cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE
    val_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE
    # Scales start after the packed values of all tokens in the cache block.
    scale_ptr = (
        cache_block_ptr
        + kv_cache_block_size * TOKEN_STRIDE
        + kv_pos_in_block * SCALE_DIM
    )
    NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM
    HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2
    # ── Register-based GPT-J forward RoPE in fp32 ─────────────────────
    # We keep the even/odd halves (no tl.interleave afterwards) because the
    # MXFP4 per-block absmax / pack naturally operates on (even, odd) pairs.
    NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2
    NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2
    normed_2d = tl.reshape(normed, (NUM_PAIRS, 2))
    even, odd = tl.split(normed_2d)  # each [NUM_PAIRS] fp32
    pair_idx = tl.arange(0, NUM_PAIRS)
    rope_pair_local = pair_idx - NOPE_PAIRS
    is_rope_pair = rope_pair_local >= 0
    # Clamp so masked (non-RoPE) lanes still form a non-negative index.
    cs_idx = tl.maximum(rope_pair_local, 0)
    # RoPE uses the position rounded down to a multiple of COMPRESS_RATIO.
    compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO
    cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride
    # Non-RoPE pairs get cos=1 / sin=0, i.e. an identity rotation. The cache
    # row holds the cos values first and the sin values HALF_ROPE later.
    cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0)
    sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0)
    new_even = even * cos_v - odd * sin_v
    new_odd = odd * cos_v + even * sin_v
    # bf16 roundtrip for parity with reference / Q-side kernel numerics.
    new_even = new_even.to(tl.bfloat16).to(tl.float32)
    new_odd = new_odd.to(tl.bfloat16).to(tl.float32)
    # ── MXFP4 quant: tile even/odd halves into (N_BLOCKS, HALF_BLOCK) ──
    # Each MXFP4 block of QUANT_BLOCK elements = HALF_BLOCK consecutive pairs,
    # so (N_BLOCKS, HALF_BLOCK) rows of even/odd each land exactly one block.
    N_QUANT_BLOCKS: tl.constexpr = HEAD_SIZE // QUANT_BLOCK
    HALF_BLOCK: tl.constexpr = QUANT_BLOCK // 2
    # The unmasked stores below rely on these layout invariants.
    tl.static_assert(TRITON_BLOCK_SIZE == HEAD_SIZE)
    tl.static_assert(HEAD_SIZE % QUANT_BLOCK == 0)
    tl.static_assert(TOKEN_STRIDE == HEAD_SIZE // 2)
    tl.static_assert(SCALE_DIM == N_QUANT_BLOCKS)
    even_2d = tl.reshape(new_even, (N_QUANT_BLOCKS, HALF_BLOCK))
    odd_2d = tl.reshape(new_odd, (N_QUANT_BLOCKS, HALF_BLOCK))
    amax = tl.maximum(
        tl.max(tl.abs(even_2d), axis=1),
        tl.max(tl.abs(odd_2d), axis=1),
    )
    # Floor the absmax so log2 below never sees zero.
    amax = tl.maximum(amax, 1e-4)
    # ue8m0 block scale: 2^ceil(log2(amax / 6.0)), stored as (exp + 127) byte.
    log2_ratio = tl.ceil(tl.log2(amax / 6.0))
    log2_ratio = tl.minimum(tl.maximum(log2_ratio, -127.0), 127.0)
    inv_scale = tl.exp2(-log2_ratio)
    ue8m0 = (log2_ratio + 127.0).to(tl.uint8)  # [N_QUANT_BLOCKS]
    inv_scale_col = tl.reshape(inv_scale, (N_QUANT_BLOCKS, 1))
    lo_nib = _e2m1_nibble(even_2d * inv_scale_col)  # (N_BLOCKS, HALF_BLOCK) uint8
    hi_nib = _e2m1_nibble(odd_2d * inv_scale_col)
    # Even element in the low nibble, odd element in the high nibble.
    packed = lo_nib | (hi_nib << 4)
    packed_flat = tl.reshape(packed, (TOKEN_STRIDE,))
    tl.store(val_ptr + tl.arange(0, TOKEN_STRIDE), packed_flat)
    tl.store(scale_ptr + tl.arange(0, SCALE_DIM), ue8m0)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
# MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale.
MXFP4_BLOCK_SIZE = 32
@triton.jit
def _get_cos_sin(
    cos_sin_cache_ptr,
    cos_sin_cache_stride,
    pos,
    HALF_ROT_DIM: tl.constexpr,
):
    """Load the cos and sin vectors for position ``pos`` as fp32.

    The cache row for ``pos`` starts at ``pos * cos_sin_cache_stride``; its
    first ``HALF_ROT_DIM`` entries are read as cos and the following
    ``HALF_ROT_DIM`` entries as sin.
    """
    lane = tl.arange(0, HALF_ROT_DIM)
    row_ptr = cos_sin_cache_ptr + pos * cos_sin_cache_stride
    cos = tl.load(row_ptr + lane).to(tl.float32)
    sin = tl.load(row_ptr + lane + HALF_ROT_DIM).to(tl.float32)
    return cos, sin
@triton.jit
def _e2m1_nibble(x):
    """Quantize fp32 x (already scale-divided) to an E2M1 4-bit nibble in uint8.

    The 3-bit magnitude code matches torch.bucketize with boundaries
    [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] and right=False (a value equal to
    a boundary falls in the lower bucket). Bit 3 carries the sign and is only
    set when the magnitude code is non-zero.
    """
    mag = tl.minimum(tl.abs(x), 6.0)
    # Start from the widest threshold and let each tighter threshold override,
    # so the tightest matching bucket wins.
    level = tl.where(mag <= 5.0, 6.0, 7.0)
    level = tl.where(mag <= 3.5, 5.0, level)
    level = tl.where(mag <= 2.5, 4.0, level)
    level = tl.where(mag <= 1.75, 3.0, level)
    level = tl.where(mag <= 1.25, 2.0, level)
    level = tl.where(mag <= 0.75, 1.0, level)
    level = tl.where(mag <= 0.25, 0.0, level)
    magnitude_bits = level.to(tl.uint8)
    # Never emit a negative zero: the sign bit needs a non-zero magnitude.
    is_negative = (x < 0) & (magnitude_bits != 0)
    return magnitude_bits | (is_negative.to(tl.uint8) << 3)
@triton.jit
def _quantize_mxfp4_pair(x_lo, x_hi):
    """Quantize one MXFP4 block supplied as two interleaved fp32 halves.

    ``x_lo`` holds the values at even positions of the block and ``x_hi`` the
    values at odd positions. Returns:
      - packed : uint8, one byte per (x_lo, x_hi) pair
                 (low nibble = quant(x_lo), high nibble = quant(x_hi))
      - ue8m0  : scalar uint8 (block scale = 2^(ue8m0 - 127))
    """
    peak_lo = tl.max(tl.abs(x_lo))
    peak_hi = tl.max(tl.abs(x_hi))
    # Floor the absmax so the log2 below never sees zero.
    block_amax = tl.maximum(tl.maximum(peak_lo, peak_hi), 1e-4)
    # Block scale exponent: ceil(log2(amax / 6.0)), clamped to [-127, 127].
    exponent = tl.math.ceil(tl.math.log2(block_amax / 6.0))
    exponent = tl.minimum(tl.maximum(exponent, -127.0), 127.0)
    ue8m0 = (exponent + 127.0).to(tl.uint8)
    inv_scale = 1.0 / tl.math.exp2(exponent)
    nib_lo = _e2m1_nibble(x_lo * inv_scale)
    nib_hi = _e2m1_nibble(x_hi * inv_scale)
    return nib_lo | (nib_hi << 4), ue8m0
@triton.jit
def _fused_indexer_q_rope_quant_kernel(
    pos_ptr,
    # Index Q RoPE
    index_q_ptr,
    index_q_stride0,
    index_q_stride1,
    index_q_cos_sin_ptr,
    index_q_cos_sin_stride,
    INDEX_Q_HALF_ROT_DIM: tl.constexpr,
    # Index Q Quantize
    index_q_fp8_ptr,
    index_q_fp8_stride0,
    index_q_fp8_stride1,
    INDEX_Q_HEAD_DIM: tl.constexpr,
    # Index weights
    index_weights_ptr,
    index_weights_stride,
    index_weights_softmax_scale,
    index_weights_head_scale,
    index_weights_out_ptr,
    index_weights_out_stride,
):
    """Fused GPT-J RoPE + per-(token, head) FP8 quantization of the index Q.

    Grid: axis 0 = token, axis 1 = head. For each (token, head) the kernel
      1. rotates the last ``2 * INDEX_Q_HALF_ROT_DIM`` dims of the head,
      2. derives one power-of-two scale from the absmax of the whole head,
      3. stores the head divided by that scale as ``float8e4nv``, and
      4. stores ``weight * q_scale * softmax_scale * head_scale`` as the
         output weight (the Q scale is folded into the weight; it is not
         written anywhere else).
    """
    # Layout matches the unfused reference (DeepseekV4ScalingRotaryEmbedding
    # + per_token_group_quant_fp8): GPT-J interleaved RoPE applied to the
    # LAST rope_dim dims of each head; the leading [0, NOPE_DIM) is passed
    # through unchanged.
    INDEX_Q_ROT_DIM: tl.constexpr = 2 * INDEX_Q_HALF_ROT_DIM
    INDEX_Q_NOPE_DIM: tl.constexpr = INDEX_Q_HEAD_DIM - INDEX_Q_ROT_DIM
    tl.static_assert(INDEX_Q_NOPE_DIM >= 0)
    # int64: program_id is int32, so `tok_idx * stride` would otherwise be an
    # int32 multiply that wraps once the element offset exceeds 2**31 - 1
    # (large chunked prefill with wide Q rows).
    tok_idx = tl.program_id(0).to(tl.int64)
    head_idx = tl.program_id(1)
    pos = tl.load(pos_ptr + tok_idx)
    cos, sin = _get_cos_sin(
        index_q_cos_sin_ptr,
        index_q_cos_sin_stride,
        pos,
        INDEX_Q_HALF_ROT_DIM,
    )
    half_offset = tl.arange(0, INDEX_Q_HALF_ROT_DIM)
    base_ptr = index_q_ptr + tok_idx * index_q_stride0 + head_idx * index_q_stride1
    # Interleaved (GPT-J) RoPE on dims [NOPE_DIM, HEAD_DIM):
    #   even = q[NOPE_DIM + 2*i], odd = q[NOPE_DIM + 2*i + 1]
    rot_base = base_ptr + INDEX_Q_NOPE_DIM
    x_even = tl.load(rot_base + half_offset * 2).to(tl.float32)
    x_odd = tl.load(rot_base + half_offset * 2 + 1).to(tl.float32)
    r_even = x_even * cos - x_odd * sin
    r_odd = x_odd * cos + x_even * sin
    # Match reference numerics: fp32 → bf16 → fp32 before the ue8m0 absmax.
    # Same pattern as the K-side compressor kernel (fused_compress_quant_cache.py).
    r_even = r_even.to(tl.bfloat16).to(tl.float32)
    r_odd = r_odd.to(tl.bfloat16).to(tl.float32)
    amax = tl.maximum(tl.max(tl.abs(r_even)), tl.max(tl.abs(r_odd)))
    # The NoPE part is only touched when it exists (compile-time branch).
    if INDEX_Q_NOPE_DIM > 0:
        nope_offset = tl.arange(0, INDEX_Q_NOPE_DIM)
        x_nope = tl.load(base_ptr + nope_offset).to(tl.float32)
        amax = tl.maximum(amax, tl.max(tl.abs(x_nope)))
    # Scale = amax / 448 rounded up to the next power of two; the absmax is
    # floored at 1e-4 so log2 never sees zero.
    index_q_scale = tl.div_rn(tl.maximum(amax, 1e-4), 448.0)
    index_q_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(index_q_scale)))
    # Store quantized values to index_q_fp8
    fp8_base_ptr = (
        index_q_fp8_ptr + tok_idx * index_q_fp8_stride0 + head_idx * index_q_fp8_stride1
    )
    if INDEX_Q_NOPE_DIM > 0:
        tl.store(
            fp8_base_ptr + nope_offset,
            tl.div_rn(x_nope, index_q_scale).to(tl.float8e4nv),
        )
    fp8_rot_base = fp8_base_ptr + INDEX_Q_NOPE_DIM
    tl.store(
        fp8_rot_base + half_offset * 2,
        tl.div_rn(r_even, index_q_scale).to(tl.float8e4nv),
    )
    tl.store(
        fp8_rot_base + half_offset * 2 + 1,
        tl.div_rn(r_odd, index_q_scale).to(tl.float8e4nv),
    )
    # FP8 weight-fold contract:
    #   index_weights_out = index_weights * q_scale * softmax_scale * head_scale
    # The per-token-per-head q_scale (fp32) IS folded into the output weights
    # here because FP8 Q is stored WITHOUT a companion scale tensor — the
    # downstream fp8_fp4_mqa_logits/fp8_fp4_paged_mqa_logits kernels use `weights` to
    # apply per-token Q scale inline. See the MXFP4 kernel below for the
    # contrasting convention (scales live with the Q values, weights are NOT
    # q-scaled).
    index_weights = tl.load(
        index_weights_ptr + tok_idx * index_weights_stride + head_idx
    )
    index_weights = index_weights.to(tl.float32)
    index_weights *= index_q_scale
    index_weights *= index_weights_softmax_scale
    index_weights *= index_weights_head_scale
    tl.store(
        index_weights_out_ptr + tok_idx * index_weights_out_stride + head_idx,
        index_weights,
    )
@triton.jit
def _fused_indexer_q_rope_mxfp4_kernel(
    pos_ptr,
    # Index Q RoPE input (fp/bf16)
    index_q_ptr,
    index_q_stride0,
    index_q_stride1,
    index_q_cos_sin_ptr,
    index_q_cos_sin_stride,
    INDEX_Q_HALF_ROT_DIM: tl.constexpr,
    # MXFP4 Q outputs
    index_q_mxfp4_ptr,  # uint8, (T, H, HEAD_DIM // 2)
    index_q_mxfp4_stride0,
    index_q_mxfp4_stride1,
    index_q_scale_ptr,  # uint8 ue8m0, (T, H, HEAD_DIM // BLOCK)
    index_q_scale_stride0,
    index_q_scale_stride1,
    INDEX_Q_HEAD_DIM: tl.constexpr,
    MXFP4_BLOCK: tl.constexpr,
    # Weights (NO per-token q_scale fold for MXFP4; per-block scales stay
    # with the Q values in the output scale tensor).
    index_weights_ptr,
    index_weights_stride,
    index_weights_softmax_scale,
    index_weights_head_scale,
    index_weights_out_ptr,
    index_weights_out_stride,
):
    """Fused GPT-J RoPE + MXFP4 quantization of the index Q.

    Grid: axis 0 = token, axis 1 = head. Each head is processed in blocks of
    ``MXFP4_BLOCK`` elements: the leading NoPE blocks are quantized as loaded,
    the trailing RoPE blocks are rotated first. Every block produces
    ``MXFP4_BLOCK // 2`` packed bytes plus one ue8m0 scale byte. The output
    weight is ``weight * softmax_scale * head_scale`` (no Q scale folded in).
    """
    INDEX_Q_ROT_DIM: tl.constexpr = 2 * INDEX_Q_HALF_ROT_DIM
    INDEX_Q_NOPE_DIM: tl.constexpr = INDEX_Q_HEAD_DIM - INDEX_Q_ROT_DIM
    NUM_NOPE_BLOCKS: tl.constexpr = INDEX_Q_NOPE_DIM // MXFP4_BLOCK
    NUM_ROPE_BLOCKS: tl.constexpr = INDEX_Q_ROT_DIM // MXFP4_BLOCK
    HALF_BLOCK: tl.constexpr = MXFP4_BLOCK // 2
    # Both sections must tile exactly into blocks, and blocks into pairs.
    tl.static_assert(INDEX_Q_NOPE_DIM >= 0)
    tl.static_assert(INDEX_Q_NOPE_DIM % MXFP4_BLOCK == 0)
    tl.static_assert(INDEX_Q_ROT_DIM % MXFP4_BLOCK == 0)
    tl.static_assert(MXFP4_BLOCK % 2 == 0)
    # int64: program_id is int32, so `tok_idx * stride` would otherwise be an
    # int32 multiply that wraps once the element offset exceeds 2**31 - 1
    # (large chunked prefill with wide Q rows).
    tok_idx = tl.program_id(0).to(tl.int64)
    head_idx = tl.program_id(1)
    pos = tl.load(pos_ptr + tok_idx)
    q_base = index_q_ptr + tok_idx * index_q_stride0 + head_idx * index_q_stride1
    out_base = (
        index_q_mxfp4_ptr
        + tok_idx * index_q_mxfp4_stride0
        + head_idx * index_q_mxfp4_stride1
    )
    scale_base = (
        index_q_scale_ptr
        + tok_idx * index_q_scale_stride0
        + head_idx * index_q_scale_stride1
    )
    half_off = tl.arange(0, HALF_BLOCK)
    # ---- NoPE blocks: direct load, pair as (even-index, odd-index) values ----
    for b in tl.static_range(NUM_NOPE_BLOCKS):
        base = b * MXFP4_BLOCK
        x_lo = tl.load(q_base + base + half_off * 2).to(tl.float32)
        x_hi = tl.load(q_base + base + half_off * 2 + 1).to(tl.float32)
        packed, ue8m0 = _quantize_mxfp4_pair(x_lo, x_hi)
        # Two values per byte, so the byte offset is half the element offset.
        tl.store(out_base + base // 2 + half_off, packed)
        tl.store(scale_base + b, ue8m0)
    # ---- RoPE blocks: apply GPT-J interleaved RoPE to the block's 16 pairs,
    # then quantize. Each block covers HALF_BLOCK (=16) cos/sin pairs. ----
    rot_q_base = q_base + INDEX_Q_NOPE_DIM
    for b in tl.static_range(NUM_ROPE_BLOCKS):
        pair_off = b * HALF_BLOCK + half_off  # indices in [0, HALF_ROT_DIM)
        # The cache row holds cos first and sin INDEX_Q_HALF_ROT_DIM later.
        cos_b = tl.load(
            index_q_cos_sin_ptr + pos * index_q_cos_sin_stride + pair_off
        ).to(tl.float32)
        sin_b = tl.load(
            index_q_cos_sin_ptr
            + pos * index_q_cos_sin_stride
            + pair_off
            + INDEX_Q_HALF_ROT_DIM
        ).to(tl.float32)
        x_even = tl.load(rot_q_base + pair_off * 2).to(tl.float32)
        x_odd = tl.load(rot_q_base + pair_off * 2 + 1).to(tl.float32)
        r_even = x_even * cos_b - x_odd * sin_b
        r_odd = x_odd * cos_b + x_even * sin_b
        # bf16 roundtrip for parity with the FP8 kernel / reference numerics.
        r_even = r_even.to(tl.bfloat16).to(tl.float32)
        r_odd = r_odd.to(tl.bfloat16).to(tl.float32)
        packed, ue8m0 = _quantize_mxfp4_pair(r_even, r_odd)
        rope_byte_off = (INDEX_Q_NOPE_DIM + b * MXFP4_BLOCK) // 2
        tl.store(out_base + rope_byte_off + half_off, packed)
        # RoPE block scales follow the NoPE block scales.
        tl.store(scale_base + NUM_NOPE_BLOCKS + b, ue8m0)
    # MXFP4 weight-fold contract:
    #   index_weights_out = index_weights * softmax_scale * head_scale
    # NOTE: q_scale is NOT folded here (contrast with the FP8 kernel above).
    # MXFP4 Q emits a separate ue8m0 scale tensor of shape
    # (T, H, HEAD_DIM // MXFP4_BLOCK) alongside the packed values, so each
    # per-block scale is applied by the downstream MXFP4 logits kernel when
    # dequantizing Q — there is no per-token scalar to fold into `weights`.
    index_weights = tl.load(
        index_weights_ptr + tok_idx * index_weights_stride + head_idx
    ).to(tl.float32)
    index_weights *= index_weights_softmax_scale
    index_weights *= index_weights_head_scale
    tl.store(
        index_weights_out_ptr + tok_idx * index_weights_out_stride + head_idx,
        index_weights,
    )
def fused_indexer_q_rope_quant(
positions: torch.Tensor,
index_q: torch.Tensor,
index_q_cos_sin_cache: torch.Tensor,
# Index weights
index_weights: torch.Tensor,
index_weights_softmax_scale: float,
index_weights_head_scale: float,
use_fp4: bool = False,
) -> tuple[
torch.Tensor | tuple[torch.Tensor, torch.Tensor],
torch.Tensor,
]:
"""Fused RoPE + quantize Q for the sparse indexer.
Weight-fold semantics (important — the two paths differ):
FP8 path (use_fp4=False, default):
q_fp8 : (T, H, HEAD_DIM) float8_e4m3fn, per-token-per-head
scalar scale (NOT stored — folded into weights below)
weights_out = weights * q_scale * softmax_scale * head_scale
Rationale: a single per-token q_scale is a scalar the downstream FP8
logits kernel would otherwise multiply in. Folding it into `weights`
avoids emitting a separate tensor and is free for the logits kernel.
MXFP4 path (use_fp4=True):
q_packed : (T, H, HEAD_DIM // 2) uint8 (2 E2M1 nibbles per byte)
q_scale : (T, H, HEAD_DIM // MXFP4_BLOCK_SIZE) uint8 ue8m0 bytes
weights_out = weights * softmax_scale * head_scale
Rationale: MXFP4 has PER-BLOCK (32-element) scales that live with
the Q values — they cannot be folded into a per-token weight
scalar, so `weights` carries only the softmax and head scales.
Returns (q_quant, weights_out) where q_quant is either a Tensor (FP8) or
a (values, scales) tuple (MXFP4). This matches the union type accepted
by `SparseAttnIndexer.forward_*`.
"""
assert positions.ndim == 1
assert index_q.ndim == 3
assert index_q_cos_sin_cache.ndim == 2
num_tokens = positions.shape[0]
num_index_q_heads = index_q.shape[1]
index_q_head_dim = index_q.shape[2]
index_weights_out = torch.empty_like(index_weights, dtype=torch.float32)
if use_fp4:
assert index_q_head_dim % MXFP4_BLOCK_SIZE == 0, (
f"head_dim={index_q_head_dim} must be a multiple of MXFP4 block "
f"size {MXFP4_BLOCK_SIZE}"
)
num_scale_blocks = index_q_head_dim // MXFP4_BLOCK_SIZE
index_q_packed = torch.empty(
(num_tokens, num_index_q_heads, index_q_head_dim // 2),
dtype=torch.uint8,
device=index_q.device,
)
index_q_scale = torch.empty(
(num_tokens, num_index_q_heads, num_scale_blocks),
dtype=torch.uint8,
device=index_q.device,
)
_fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)](
positions,
index_q,
index_q.stride(0),
index_q.stride(1),
index_q_cos_sin_cache,
index_q_cos_sin_cache.stride(0),
index_q_cos_sin_cache.shape[-1] // 2,
index_q_packed,
index_q_packed.stride(0),
index_q_packed.stride(1),
index_q_scale,
index_q_scale.stride(0),
index_q_scale.stride(1),
index_q_head_dim,
MXFP4_BLOCK_SIZE,
index_weights,
index_weights.stride(0),
index_weights_softmax_scale,
index_weights_head_scale,
index_weights_out,
index_weights_out.stride(0),
num_warps=1, # TODO: Tune this
)
# Values stay uint8 (2 E2M1 nibbles per byte). Scales are 4 ue8m0
# bytes per (token, head) reinterpreted as one int32, then squeezed
# from (T, H, 1) to (T, H) to match DeepGEMM's expected q_sf rank
# (prefill wants 2-D (seq_len, num_heads); decode reshapes this to
# 3-D (batch, next_n, num_heads)).
return (
index_q_packed,
index_q_scale.view(torch.int32).squeeze(-1),
), index_weights_out
index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn)
_fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)](
positions,
index_q,
index_q.stride(0),
index_q.stride(1),
index_q_cos_sin_cache,
index_q_cos_sin_cache.stride(0),
index_q_cos_sin_cache.shape[-1] // 2,
index_q_fp8,
index_q_fp8.stride(0),
index_q_fp8.stride(1),
index_q_head_dim,
index_weights,
index_weights.stride(0),
index_weights_softmax_scale,
index_weights_head_scale,
index_weights_out,
index_weights_out.stride(0),
num_warps=1, # TODO: Tune this
)
return index_q_fp8, index_weights_out
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fused inverse RoPE + block-scaled FP8 quantization kernel for DeepseekV4 attention.
Output scale format is pre-transformed (MN-major TMA-aligned; FP32 on SM90,
INT32-packed UE8M0 on SM100) so fp8_einsum skips transform_sf_into_required_layout.
"""
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _fused_inv_rope_fp8_quant_per_head(
    o_ptr,
    positions_ptr,
    cos_sin_cache_ptr,
    fp8_ptr,
    scale_ptr,
    num_tokens,
    heads_per_group: tl.constexpr,
    o_stride_token,
    o_stride_head,
    cache_stride_pos,
    fp8_stride_group,
    fp8_stride_token,
    scale_stride_group,
    scale_stride_k,
    fp8_max: tl.constexpr,
    eps: tl.constexpr,
    QUANT_GROUP_SIZE: tl.constexpr,
    CHUNKS_PER_HEAD: tl.constexpr,
    ROPE_START: tl.constexpr,
    HALF_ROPE: tl.constexpr,
    TMA_ALIGNED_SCALES: tl.constexpr,
):
    """Inverse GPT-J RoPE on one head, then per-chunk power-of-two FP8 quant.

    Grid: axis 0 = token row, axis 1 = global head index. Programs whose token
    row is >= ``num_tokens`` are padding rows: they only zero their scale
    entries and return.

    A head is ``CHUNKS_PER_HEAD * QUANT_GROUP_SIZE`` elements. The rotation is
    applied from element ``(CHUNKS_PER_HEAD - 1) * QUANT_GROUP_SIZE +
    ROPE_START`` to the end of the head; everything before it passes through.
    Each ``QUANT_GROUP_SIZE`` chunk gets its own scale
    ``2^ceil(log2(max(absmax, eps) / fp8_max))``.

    Scale output:
      - ``TMA_ALIGNED_SCALES``: the chunk exponents of one head are packed,
        one byte each, into a single int32 stored at k-index ``head_in_group``.
      - otherwise: one fp32 scale per chunk at k-index
        ``head_in_group * CHUNKS_PER_HEAD + chunk``.
    In both cases the token index is added to the scale address with an
    implicit stride of 1.
    """
    # int64: stride multiply overflows int32 past num_tokens=32768 (IMA).
    pid_token = tl.program_id(0).to(tl.int64)
    pid_gh = tl.program_id(1).to(tl.int64)
    # Split the flat head index into (group, head within group).
    g = pid_gh // heads_per_group
    head_in_group = pid_gh % heads_per_group
    global_head = pid_gh
    # Index of this head's first quantization chunk within its group row.
    qb_start = head_in_group * CHUNKS_PER_HEAD
    # Padding rows in the TMA-aligned scale buffer: fill with zero and skip quant.
    if pid_token >= num_tokens:
        if TMA_ALIGNED_SCALES:
            scale_addr = (
                scale_ptr
                + g * scale_stride_group
                + pid_token
                + head_in_group * scale_stride_k
            )
            tl.store(scale_addr, tl.zeros((), dtype=tl.int32))
        else:
            block_offsets = tl.arange(0, CHUNKS_PER_HEAD)
            qb_indices = qb_start + block_offsets
            scale_addrs = (
                scale_ptr
                + g * scale_stride_group
                + pid_token
                + qb_indices * scale_stride_k
            )
            tl.store(scale_addrs, tl.zeros((CHUNKS_PER_HEAD,), dtype=tl.float32))
        return
    input_base = o_ptr + pid_token * o_stride_token + global_head * o_stride_head
    HEAD_DIM: tl.constexpr = CHUNKS_PER_HEAD * QUANT_GROUP_SIZE
    offsets = tl.arange(0, HEAD_DIM)
    # Head elements are read with an implicit unit stride.
    x = tl.load(input_base + offsets).to(tl.float32)
    # First rotated element: ROPE_START elements into the last chunk.
    rope_abs_start: tl.constexpr = (CHUNKS_PER_HEAD - 1) * QUANT_GROUP_SIZE + ROPE_START
    pos = tl.load(positions_ptr + pid_token)
    cache_base = cos_sin_cache_ptr + pos * cache_stride_pos
    is_rope = offsets >= rope_abs_start
    rope_local = offsets - rope_abs_start
    # `offsets ^ 1` swaps each element with the other member of its
    # (even, odd) pair; only needed for rotated lanes.
    x_partner = tl.load(input_base + (offsets ^ 1), mask=is_rope, other=0.0).to(
        tl.float32
    )
    # Both members of a pair share one cos/sin entry; clamp for masked lanes.
    cs_idx = tl.maximum(rope_local >> 1, 0)
    cos_v = tl.load(cache_base + cs_idx, mask=is_rope, other=1.0)
    sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope, other=0.0)
    # Inverse rotation: even lanes get x*cos + partner*sin, odd lanes get
    # x*cos - partner*sin (the forward rotation with the sin sign flipped).
    x_add = x * cos_v + x_partner * sin_v
    x_sub = x * cos_v - x_partner * sin_v
    is_even = (rope_local & 1) == 0
    rotated = tl.where(is_even, x_add, x_sub)
    x = tl.where(is_rope, rotated, x)
    # Per-chunk absmax, floored at eps so log2 never sees zero.
    x_2d = tl.reshape(tl.abs(x), (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE))
    block_absmax = tl.maximum(tl.max(x_2d, axis=1), eps)
    scale_raw = block_absmax * (1.0 / fp8_max)
    # Round each scale up to the next power of two.
    scales = tl.math.exp2(tl.ceil(tl.log2(scale_raw)))
    # Broadcast every chunk scale over its QUANT_GROUP_SIZE elements.
    scales_exp = tl.reshape(
        tl.broadcast_to(
            tl.reshape(scales, (CHUNKS_PER_HEAD, 1)),
            (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE),
        ),
        (HEAD_DIM,),
    )
    x_quant = tl.clamp(x / scales_exp, -fp8_max, fp8_max).to(tl.float8e4nv)
    # Output is grouped: this head's slice starts at its first chunk within
    # the (group, token) row.
    fp8_base = (
        fp8_ptr
        + g * fp8_stride_group
        + pid_token * fp8_stride_token
        + qb_start * QUANT_GROUP_SIZE
    )
    tl.store(fp8_base + offsets, x_quant)
    block_offsets = tl.arange(0, CHUNKS_PER_HEAD)
    qb_indices = qb_start + block_offsets
    if TMA_ALIGNED_SCALES:
        # The scale is an exact power of two, so bits 23..30 of its fp32
        # encoding (the biased exponent) describe it completely.
        scale_bits = scales.to(tl.int32, bitcast=True)
        ue8m0_bytes = (scale_bits >> 23) & 0xFF
        # Chunk i occupies byte i of the packed int32.
        # NOTE(review): this only fits when CHUNKS_PER_HEAD <= 4 — confirm the
        # launcher guarantees it.
        packed_val = tl.sum(ue8m0_bytes << (block_offsets * 8))
        scale_addr = (
            scale_ptr
            + g * scale_stride_group
            + pid_token
            + head_in_group * scale_stride_k
        )
        tl.store(scale_addr, packed_val)
    else:
        scale_addrs = (
            scale_ptr + g * scale_stride_group + pid_token + qb_indices * scale_stride_k
        )
        tl.store(scale_addrs, scales)
def fused_inv_rope_fp8_quant(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int = 448,
    rope_dim: int = 64,
    quant_group_size: int = 128,
    tma_aligned_scales: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused inverse RoPE + block-scaled FP8 quantization.

    Args:
        o: Attention output [num_tokens, num_heads, head_dim] bf16, with a
            unit stride on the last dimension.
        positions: Token positions [num_tokens] int64.
        cos_sin_cache: Precomputed [max_pos, rope_dim] float32 with cos||sin,
            with a unit stride on the last dimension.
        n_groups: Number of output groups.
        heads_per_group: Heads per group.
        nope_dim: Non-RoPE dimensions per head (default 448).
        rope_dim: RoPE dimensions per head (default 64).
        quant_group_size: FP8 quantization block size (default 128).
        tma_aligned_scales: Output INT32 packed UE8M0 for SM100 (True)
            or FP32 for SM90 (False). The packed format requires exactly
            four quantization blocks per head.

    Returns:
        o_fp8: [T, G, D] float8_e4m3fn, strides (D, T*D, 1).
        o_scale: Pre-transformed scale tensor for fp8_einsum.

    Raises:
        AssertionError: If the shapes, dtypes or strides do not satisfy the
            layout the kernel relies on.
    """
    num_tokens, num_heads, head_dim = o.shape
    assert num_heads == n_groups * heads_per_group
    assert head_dim == nope_dim + rope_dim
    assert head_dim % quant_group_size == 0
    # The rotated dims must be exactly the tail of the last quant block.
    assert nope_dim % quant_group_size == (quant_group_size - rope_dim)
    assert rope_dim % 2 == 0
    assert cos_sin_cache.shape[-1] == rope_dim
    assert cos_sin_cache.dtype == torch.float32
    # The kernel indexes head elements and cos/sin entries with an implicit
    # unit stride; any other layout would silently read the wrong values.
    assert o.stride(-1) == 1 and cos_sin_cache.stride(-1) == 1, (
        "o and cos_sin_cache must have unit stride on the last dimension"
    )
    d = heads_per_group * head_dim
    num_scale_blocks = d // quant_group_size
    chunks_per_head = head_dim // quant_group_size
    if tma_aligned_scales:
        # The kernel packs one head's block exponents into a single int32
        # (one byte each) and stores it at packed index `head_in_group`.
        # Any count other than 4 either shifts past bit 31 or addresses
        # outside the packed scale buffer.
        assert chunks_per_head == 4, (
            "tma_aligned_scales requires head_dim == 4 * quant_group_size, "
            f"got head_dim={head_dim}, quant_group_size={quant_group_size}"
        )

    # Imported after validation so bad arguments are reported before the
    # optional DeepGEMM dependency is touched.
    from vllm.utils.deep_gemm import get_tma_aligned_size

    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    fp8_buf = torch.empty(
        (n_groups, num_tokens, d),
        dtype=fp8_dtype,
        device=o.device,
    )
    # Scale rows are padded to the TMA-aligned token count; the extra rows
    # are zero-filled by the kernel's padding programs.
    tma_aligned_T = get_tma_aligned_size(num_tokens, 4)
    if tma_aligned_scales:
        packed_sf_k = (num_scale_blocks + 3) // 4
        # Token-contiguous (MN-major) layout: stride 1 along tokens.
        scale_buf = torch.empty(
            n_groups * packed_sf_k * tma_aligned_T,
            dtype=torch.int32,
            device=o.device,
        ).as_strided(
            (n_groups, num_tokens, packed_sf_k),
            (packed_sf_k * tma_aligned_T, 1, tma_aligned_T),
        )
    else:
        scale_buf = torch.empty(
            n_groups * num_scale_blocks * tma_aligned_T,
            dtype=torch.float32,
            device=o.device,
        ).as_strided(
            (n_groups, num_tokens, num_scale_blocks),
            (num_scale_blocks * tma_aligned_T, 1, tma_aligned_T),
        )
    common_args = dict(
        heads_per_group=heads_per_group,
        o_stride_token=o.stride(0),
        o_stride_head=o.stride(1),
        cache_stride_pos=cos_sin_cache.stride(0),
        fp8_stride_group=fp8_buf.stride(0),
        fp8_stride_token=fp8_buf.stride(1),
        scale_stride_group=scale_buf.stride(0),
        scale_stride_k=scale_buf.stride(2),
        fp8_max=fp8_max,
        eps=1e-10,
        QUANT_GROUP_SIZE=quant_group_size,
        CHUNKS_PER_HEAD=chunks_per_head,
        ROPE_START=nope_dim % quant_group_size,
        HALF_ROPE=rope_dim // 2,
        TMA_ALIGNED_SCALES=tma_aligned_scales,
        num_stages=1,
        launch_pdl=False,
    )
    # Axis 0 spans the padded token count so padding rows get zeroed scales.
    grid = (tma_aligned_T, n_groups * heads_per_group)
    _fused_inv_rope_fp8_quant_per_head[grid](
        o,
        positions,
        cos_sin_cache,
        fp8_buf,
        scale_buf,
        num_tokens,
        **common_args,
        num_warps=1,
    )
    # Buffers are allocated group-major; hand them back token-major.
    return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _fused_q_kv_rmsnorm_kernel(
    q_ptr,
    q_out_ptr,
    q_weight_ptr,
    q_in_stride,
    q_out_stride,
    kv_ptr,
    kv_out_ptr,
    kv_weight_ptr,
    kv_in_stride,
    kv_out_stride,
    eps,
    Q_SIZE: tl.constexpr,
    KV_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """RMSNorm one row of either the Q or the KV input.

    Grid: axis 0 = token, axis 1 = task (0 normalizes the Q row of width
    ``Q_SIZE``, any other value normalizes the KV row of width ``KV_SIZE``).
    Both tasks share ``BLOCK_SIZE`` lanes; lanes beyond the row width are
    masked on load and store.
    """
    # num_tokens goes on grid-x (max 2**31 - 1); task goes on grid-y.
    # CUDA's grid-y/z are capped at 65535, so putting num_tokens there crashes
    # the launch at max-num-batched-tokens >= 65536 with "invalid argument".
    # int64: q_in_stride can be ~24K (128 heads × 192) and overflows int32
    # past num_tokens ~87K under large chunked prefill.
    token_idx = tl.program_id(0).to(tl.int64)
    pid_task = tl.program_id(1)
    # Select the row pointers, weight and width for this program's task.
    if pid_task == 0:
        SIZE = Q_SIZE
        row_in = q_ptr + token_idx * q_in_stride
        weight_ptr = q_weight_ptr
        row_out = q_out_ptr + token_idx * q_out_stride
    else:
        SIZE = KV_SIZE
        row_in = kv_ptr + token_idx * kv_in_stride
        weight_ptr = kv_weight_ptr
        row_out = kv_out_ptr + token_idx * kv_out_stride
    # RMSNorm in fp32 throughout — matches csrc/layernorm_kernels.cu's
    # `(scalar_t)(x * s_variance * w)` and DeepseekV4's compressor kernel, which
    # keep x, rrms, and w all in fp32 and perform a single cast at store.
    block = tl.arange(0, BLOCK_SIZE)
    mask = block < SIZE
    # Masked lanes load 0.0, so they add nothing to the sum of squares.
    x = tl.load(row_in + block, mask=mask, other=0.0).to(tl.float32)
    # Mean of squares over the real row width, not the padded block width.
    variance = tl.sum(x * x, axis=0) / SIZE
    rrms = tl.rsqrt(variance + eps)
    w = tl.load(weight_ptr + block, mask=mask, other=0.0).to(tl.float32)
    y = x * rrms * w
    # Single cast to the output element type at store time.
    tl.store(row_out + block, y.to(row_out.dtype.element_ty), mask=mask)
def fused_q_kv_rmsnorm(
    qr: torch.Tensor,
    kv: torch.Tensor,
    q_weight: torch.Tensor,
    kv_weight: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm the Q and KV rows of every token with a single kernel launch.

    ``qr`` and ``kv`` are 2-D with the same number of rows and a unit stride
    on the last dimension; both weights must be contiguous. Returns freshly
    allocated ``(qr_out, kv_out)`` tensors created with ``torch.empty_like``.
    With zero tokens the outputs are returned without launching the kernel.
    """
    assert qr.ndim == 2 and kv.ndim == 2
    assert qr.shape[0] == kv.shape[0], (
        f"token dim mismatch: qr={qr.shape}, kv={kv.shape}"
    )
    assert qr.stride(-1) == 1 and kv.stride(-1) == 1
    assert q_weight.is_contiguous() and kv_weight.is_contiguous()

    num_tokens, q_size = qr.shape[0], qr.shape[1]
    kv_size = kv.shape[1]
    qr_out, kv_out = torch.empty_like(qr), torch.empty_like(kv)

    if num_tokens != 0:
        # Tokens on grid axis 0; axis 1 picks the Q (0) or KV (1) task.
        launch_grid = (num_tokens, 2)
        # One block width serves both tasks, sized for the wider row.
        lanes = triton.next_power_of_2(max(q_size, kv_size))
        _fused_q_kv_rmsnorm_kernel[launch_grid](
            qr,
            qr_out,
            q_weight,
            qr.stride(0),
            qr_out.stride(0),
            kv,
            kv_out,
            kv_weight,
            kv.stride(0),
            kv_out.stride(0),
            eps,
            Q_SIZE=q_size,
            KV_SIZE=kv_size,
            BLOCK_SIZE=lanes,
        )
    return qr_out, kv_out
......@@ -54,8 +54,14 @@ class KVCacheCoordinator(ABC):
metrics_collector,
)
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
# KV cache group indices that get the EAGLE last-block drop.
self.eagle_group_ids: set[int] = {
i for i, g in enumerate(kv_cache_config.kv_cache_groups) if g.is_eagle_group
}
# Conservatively fall back to flag all groups when no group is flagged.
if use_eagle and not self.eagle_group_ids:
self.eagle_group_ids = set(range(len(kv_cache_config.kv_cache_groups)))
self.single_type_managers = tuple(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
......@@ -357,7 +363,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
use_eagle=0 in self.eagle_group_ids,
alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
......@@ -450,6 +456,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
block_sizes = [spec.block_size for spec, _, _ in attention_groups]
self.lcm_block_size = lcm(*block_sizes)
# Attention-group indices (into ``self.attention_groups``) that
# contain at least one EAGLE/MTP KV cache group.
self.eagle_attn_group_indices: set[int] = {
i
for i, (_, group_ids, _) in enumerate(self.attention_groups)
if any(gid in self.eagle_group_ids for gid in group_ids)
}
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
......@@ -485,49 +499,62 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
# Full attn is always first if it exists. This avoids EAGLE drops
# being applied multiple times to non-full-attn groups.
# FIXME (yifan): However, for complex hybrid models with multiple attn
# groups, we still have the EAGLE spiral block dropping problem. See
# discussion in issue https://github.com/vllm-project/vllm/issues/32802.
# Full attn is always first if it exists.
is_simple_hybrid = len(self.attention_groups) == 2 and isinstance(
self.attention_groups[0][0], FullAttentionSpec
)
# Attention-group indices whose EAGLE drop is verified at the current
# ``curr_hit_length``. Each eagle group applies the drop at most once
# per candidate length (see issue #32802).
eagle_verified: set[int] = set()
while True:
curr_hit_length = hit_length
for spec, group_ids, manager_cls in self.attention_groups:
is_full_attn = isinstance(spec, FullAttentionSpec)
# Full attention: reuse cached blocks (downward-closed property)
for idx, (spec, group_ids, manager_cls) in enumerate(self.attention_groups):
cached_blocks = hit_blocks_by_group[group_ids[0]]
if is_full_attn and cached_blocks is not None:
# For full attention, we only need to compute the cache hit
# length once. Starting from the second iteration, if the
# curr_hit_length is reduced by other groups, we can simply
# keep the first (curr_hit_length // block_size) blocks from
# the last iteration.
num_blocks = curr_hit_length // spec.block_size
curr_hit_length = num_blocks * spec.block_size
else:
hit_blocks = manager_cls.find_longest_cache_hit(
block_hashes=_get_block_hashes(spec),
max_length=curr_hit_length,
kv_cache_group_ids=group_ids,
block_pool=self.block_pool,
kv_cache_spec=spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
if isinstance(spec, FullAttentionSpec) and cached_blocks is not None:
# Full attention is downward-closed: we only need to look
# up cached blocks once; on subsequent iterations just trim
# to the (reduced) current hit length.
curr_hit_length = (
curr_hit_length // spec.block_size * spec.block_size
)
curr_hit_length = len(hit_blocks[0]) * spec.block_size
for group_id, blocks in zip(group_ids, hit_blocks):
hit_blocks_by_group[group_id] = blocks
continue
use_eagle = (
idx in self.eagle_attn_group_indices and idx not in eagle_verified
)
_max_length = curr_hit_length
if use_eagle:
# Eagle needs to match one more block and then pop the last.
_max_length = min(
curr_hit_length + spec.block_size, max_cache_hit_length
)
hit_blocks = manager_cls.find_longest_cache_hit(
block_hashes=_get_block_hashes(spec),
max_length=_max_length,
kv_cache_group_ids=group_ids,
block_pool=self.block_pool,
kv_cache_spec=spec,
use_eagle=use_eagle,
alignment_tokens=self.lcm_block_size,
)
_new_hit_length = len(hit_blocks[0]) * spec.block_size
if use_eagle:
eagle_verified.add(idx)
elif _new_hit_length < curr_hit_length:
# length shrunk; invalidate previous eagle verifications
eagle_verified.clear()
curr_hit_length = _new_hit_length
for group_id, blocks in zip(group_ids, hit_blocks):
hit_blocks_by_group[group_id] = blocks
if curr_hit_length >= hit_length:
break
hit_length = curr_hit_length
# Simple hybrid: exit after one iteration
if is_simple_hybrid:
break
......
......@@ -4,18 +4,19 @@
import copy
import hashlib
import math
import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload
from typing import Any, NewType, TypeAlias, cast, overload
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.hashing import sha256_cbor, xxhash_cbor
from vllm.utils.math_utils import cdiv
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_utils import format_gib
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
......@@ -24,6 +25,9 @@ from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec,
KVCacheSpec,
KVCacheTensor,
MambaSpec,
MLAAttentionSpec,
SlidingWindowMLASpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
......@@ -562,6 +566,72 @@ def hash_block_tokens(
)
def resolve_kv_cache_block_sizes(
    kv_cache_config: KVCacheConfig,
    vllm_config: VllmConfig,
) -> tuple[int, int]:
    """Compute the ``(scheduler_block_size, hash_block_size)`` pair.

    ``scheduler_block_size`` is the token alignment the scheduler relies on
    (e.g. when rounding ``num_computed_tokens``):

    - zero or one KV cache group: ``cache_config.block_size * dcp * pcp``;
    - several groups: the LCM of all group block sizes. Context parallelism
      is rejected in this configuration.

    ``hash_block_size`` is the granularity at which ``Request.block_hashes``
    is computed:

    - zero or one group: identical to the scheduler block size;
    - several groups: ``cache_config.hash_block_size`` when it is set,
      otherwise the GCD of the group block sizes. Every group block size
      must be a multiple of the chosen value.

    With several groups the scheduler block size is returned for both
    entries (no finer-grained hashing) when neither prefix caching nor a
    KV connector is configured, or when some mamba group uses a block size
    different from ``cache_config.block_size``
    (mamba_cache_mode != "align").

    Raises:
        ValueError: if several groups are combined with dcp/pcp != 1, or if
            a group block size is not divisible by the hash block size.
    """
    cache_cfg = vllm_config.cache_config
    parallel_cfg = vllm_config.parallel_config
    dcp_size = parallel_cfg.decode_context_parallel_size
    pcp_size = parallel_cfg.prefill_context_parallel_size
    kv_groups = kv_cache_config.kv_cache_groups

    # At most one group: plain block size scaled by the CP world sizes.
    if len(kv_groups) < 2:
        single = cache_cfg.block_size * dcp_size * pcp_size
        return single, single

    if not (dcp_size == 1 and pcp_size == 1):
        raise ValueError(
            "Hybrid KV cache groups with multiple block sizes do not "
            "support context parallelism (dcp_world_size/pcp_world_size > 1)."
        )

    group_block_sizes = [grp.kv_cache_spec.block_size for grp in kv_groups]
    scheduler_block_size = math.lcm(*group_block_sizes)
    coarse = (scheduler_block_size, scheduler_block_size)

    # Only prefix caching and KV connectors (P/D, offloading) consume block
    # hashes; without either there is no reason to hash at a finer grain.
    has_connector = vllm_config.kv_transfer_config is not None
    if not cache_cfg.enable_prefix_caching and not has_connector:
        return coarse

    # A mamba group whose block size differs from the cache block size
    # (mamba_cache_mode != "align") breaks divisibility, so stay coarse.
    for grp in kv_groups:
        grp_spec = grp.kv_cache_spec
        if (
            isinstance(grp_spec, MambaSpec)
            and grp_spec.block_size != cache_cfg.block_size
        ):
            return coarse

    # An explicit override wins; otherwise fall back to the GCD.
    hash_block_size = cache_cfg.hash_block_size
    if hash_block_size is None:
        hash_block_size = math.gcd(*group_block_sizes)

    if not all(size % hash_block_size == 0 for size in group_block_sizes):
        raise ValueError(
            f"Invalid hash_block_size={hash_block_size}; all KV cache group "
            f"block sizes must be divisible by hash_block_size. "
            f"Got group block sizes={group_block_sizes}."
        )
    return scheduler_block_size, hash_block_size
def get_request_block_hasher(
block_size: int,
caching_hash_fn: Callable[[Any], bytes],
......@@ -1089,6 +1159,63 @@ def _get_kv_cache_groups_uniform_page_size(
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
def _get_kv_cache_config_deepseek_v4(
    vllm_config: VllmConfig,
    kv_cache_groups: list[KVCacheGroupSpec],
    available_memory: int,
) -> tuple[int, list[KVCacheTensor]]:
    """Plan the KV cache tensor layout for DeepseekV4.

    Precondition: ``kv_cache_groups[0]`` is the full-MLA group and its page
    sizes form the canonical bucket set. The remaining groups are expected
    to have been page_size-padded upstream (see
    ``_get_kv_cache_groups_uniform_groups``) so that each of their layers
    has a page size equal to one of those buckets.

    Within every group the layers are bucketed by ``page_size_bytes``; a
    layer's ``tuple_idx`` is its position inside its bucket. One
    ``KVCacheTensor`` is produced per ``(tuple_idx, page size)`` pair, and
    its ``shared_by`` lists the layer each group contributes at that slot
    (groups with no layer there contribute nothing).

    Returns:
        ``(num_blocks, kv_cache_tensors)``.
    """
    canonical_spec = kv_cache_groups[0].kv_cache_spec
    assert isinstance(canonical_spec, UniformTypeKVCacheSpecs)
    page_sizes = sorted(canonical_spec.get_page_sizes())
    layer_tuple_page_bytes = sum(page_sizes)

    # layers_by_size[g_idx][page_size] -> layer names in registration order.
    layers_by_size: list[dict[int, list[str]]] = []
    for group in kv_cache_groups:
        group_spec = group.kv_cache_spec
        assert isinstance(group_spec, UniformTypeKVCacheSpecs)
        per_layer = group_spec.kv_cache_specs
        by_size: dict[int, list[str]] = {}
        for layer_name in group.layer_names:
            by_size.setdefault(per_layer[layer_name].page_size_bytes, []).append(
                layer_name
            )
        layers_by_size.append(by_size)

    # The longest bucket over all groups fixes how many tuple slots exist.
    # For the full-MLA group that is the size of its largest per-page-size
    # bucket; for a SWA sub-group (single page size) it is the group size.
    num_layer_tuples = max(
        len(names) for by_size in layers_by_size for names in by_size.values()
    )
    num_blocks = may_override_num_blocks(
        vllm_config,
        available_memory // (layer_tuple_page_bytes * num_layer_tuples),
    )

    kv_cache_tensors: list[KVCacheTensor] = []
    for tuple_idx in range(num_layer_tuples):
        for page_size in page_sizes:
            # Pick, from every group, the layer sitting at this slot (if any).
            shared_by: list[str] = [
                by_size[page_size][tuple_idx]
                for by_size in layers_by_size
                if tuple_idx < len(by_size.get(page_size, ()))
            ]
            kv_cache_tensors.append(
                KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
            )
    return num_blocks, kv_cache_tensors
def get_kv_cache_config_from_groups(
vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
......@@ -1120,7 +1247,7 @@ def get_kv_cache_config_from_groups(
kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
):
# Special case: all layers have the same type of KV cache but with
# different hidden size. Allocate different amount of memory for each
# different hidden sizes. Allocate different amount of memory for each
# layer based on its hidden size.
num_blocks = (
available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes
......@@ -1136,6 +1263,15 @@ def get_kv_cache_config_from_groups(
)
for layer_name in kv_cache_groups[0].layer_names
]
elif all(
isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs)
for group in kv_cache_groups
):
# DeepseekV4: UniformTypeKVCacheSpecs but multiple groups.
# Delegate to the DeepseekV4-specific allocator.
num_blocks, kv_cache_tensors = _get_kv_cache_config_deepseek_v4(
vllm_config, kv_cache_groups, available_memory
)
else:
# General case:
# We will have group_size memory pools, each is shared by one layer from
......@@ -1206,9 +1342,41 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
has_chunked_local_attention = any(
isinstance(spec, ChunkedLocalAttentionSpec) for spec in kv_cache_spec.values()
)
has_swa_mla = any(
isinstance(spec, SlidingWindowMLASpec) for spec in kv_cache_spec.values()
)
uniform_block_size: int | None = None
if has_swa_mla:
# For DeepseekV4, block sizes can be different for different KV cache groups.
# E.g., Full MLA: 256; SWA MLA: 64; C4 partial states: 4, C128 states: 8.
assert has_full_attention
any_full_spec = next(
iter(
spec
for spec in kv_cache_spec.values()
if isinstance(spec, FullAttentionSpec)
)
)
uniform_block_size = any_full_spec.block_size
if has_full_attention and (has_sliding_window or has_chunked_local_attention):
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
if isinstance(spec, SlidingWindowMLASpec):
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=uniform_block_size
if uniform_block_size is not None
else spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
page_size_padded=spec.page_size_padded,
cache_dtype_str=spec.cache_dtype_str,
alignment=spec.alignment,
compress_ratio=spec.compress_ratio,
model_version=spec.model_version,
)
elif isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
......@@ -1237,6 +1405,204 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
)
def group_and_unify_kv_cache_specs(
    kv_cache_spec: dict[str, KVCacheSpec],
) -> list[UniformTypeKVCacheSpecs] | None:
    """
    Partition the per-layer specs into groups and unify each group into a
    single UniformTypeKVCacheSpecs. Currently only used for DeepseekV4.

    Returns None when no layer uses a SlidingWindowMLASpec. Otherwise the
    result starts with the unified full-MLA spec, followed by one unified
    spec per distinct (block_size, sliding_window) pair among the sliding
    window MLA layers, in first-seen order.
    """
    full_mla_layers: dict[str, KVCacheSpec] = {}
    swa_layers_by_key: dict[tuple[int, int], dict[str, KVCacheSpec]] = {}

    # NOTE: SWA layers are keyed by (block_size, sliding_window), which keeps
    # SWA layers, C4I+C4A layers, and C128A layers in three separate groups.
    # Relying on these two fields alone can be fragile, but is fine for now.
    for layer_name, layer_spec in kv_cache_spec.items():
        if isinstance(layer_spec, SlidingWindowMLASpec):
            key = (layer_spec.block_size, layer_spec.sliding_window)
            swa_layers_by_key.setdefault(key, {})[layer_name] = layer_spec
        elif isinstance(layer_spec, MLAAttentionSpec):
            full_mla_layers[layer_name] = layer_spec

    # No sliding-window MLA layer at all: this grouping does not apply.
    if not swa_layers_by_key:
        return None

    assert len(full_mla_layers) > 0
    full_mla_unified = UniformTypeKVCacheSpecs.from_specs(full_mla_layers)
    assert full_mla_unified is not None

    unified_groups: list[UniformTypeKVCacheSpecs] = [full_mla_unified]
    for members in swa_layers_by_key.values():
        swa_unified = UniformTypeKVCacheSpecs.from_specs(members)
        assert swa_unified is not None
        unified_groups.append(swa_unified)
    return unified_groups
def _approximate_gcd(values: Sequence[int], *, lower_bound: int | None = None) -> int:
"""Pick a chunk size that minimizes total upward padding.
Each x is rounded up to a multiple of d:
x -> ceil(x / d) * d
Total padding is:
pad(d) = sum_i (ceil(x_i / d) * d - x_i)
We brute-force d in [lower_bound, max(values)] (fine for small lists / small
maxima) and return the d with minimum padding. Ties prefer larger d.
"""
if not values:
raise ValueError("values must be non-empty")
if any(x <= 0 for x in values):
raise ValueError(f"values must be positive, got: {list(values)!r}")
min_d = max(1, lower_bound if lower_bound is not None else 1)
max_d = max(values)
if min_d > max_d:
return min_d
best_d = min_d
best_pad: int | None = None
for d in range(min_d, max_d + 1):
pad = sum((d - (x % d)) % d for x in values)
if best_pad is None or pad < best_pad or (pad == best_pad and d > best_d):
best_pad = pad
best_d = d
return best_d
def _get_kv_cache_groups_uniform_groups(
    grouped_specs: list[UniformTypeKVCacheSpecs],
) -> list[KVCacheGroupSpec]:
    """
    Generate the KV cache groups from the grouped specs.

    ``grouped_specs[0]`` must contain only MLAAttentionSpec layers (the
    full-MLA group); every remaining entry must contain only
    SlidingWindowMLASpec layers. The full-MLA spec becomes one group as-is.
    Each sliding-window spec is split into one or more sub-groups so that a
    sub-group holds roughly ``num_layer_tuples`` layer tuples.

    Side effect: layer specs inside ``grouped_specs[1:]`` whose page size is
    smaller than the nearest full-MLA page size get ``page_size_padded`` set
    in place.

    Returns:
        ``[full_mla_group, *swa_mla_groups]``.
    """
    assert len(grouped_specs) > 0 and all(
        isinstance(spec, UniformTypeKVCacheSpecs) for spec in grouped_specs
    )
    # For now, we restrict the first grouped_spec to be UniformTypeKVCacheSpecs
    # containing only MLAAttentionSpec.
    full_mla_spec = grouped_specs[0]
    assert all(
        isinstance(spec, MLAAttentionSpec)
        for spec in full_mla_spec.kv_cache_specs.values()
    )
    full_mla_group = KVCacheGroupSpec(
        layer_names=list(full_mla_spec.kv_cache_specs.keys()),
        kv_cache_spec=full_mla_spec,
    )
    # We define a layer tuple as a group of layers with different page sizes, and
    # one UniformTypeKVCacheSpecs contains a list of layer tuples.
    # For example, if we have 11 C4 layers and 10 C128 layers, we can define a layer
    # tuple as [C4I, C4A, C128], and the full_mla_group will contain "11" layer tuples.
    # The other uniform KV cache specs will be similarly partitioned into layer tuples.
    # Say we have 21 SWA layers, all with the same page size, then we will have "21"
    # layer tuples.
    num_layer_tuples_per_group: list[int] = [
        g_spec.get_num_layer_tuples() for g_spec in grouped_specs
    ]
    # Choose `num_layer_tuples` to minimize total padding across groups.
    # The lower bound is the full-MLA group's own tuple count, so the chosen
    # value is never smaller than that.
    num_layer_tuples = _approximate_gcd(
        num_layer_tuples_per_group, lower_bound=num_layer_tuples_per_group[0]
    )
    # Round up to the nearest multiple of `num_layer_tuples` (i.e., padding)
    # NOTE(review): the rounded list is not read again in this function;
    # confirm whether it is still needed.
    num_layer_tuples_per_group = [
        round_up(x, num_layer_tuples) for x in num_layer_tuples_per_group
    ]
    swa_mla_specs = grouped_specs[1:]
    assert all(
        isinstance(spec, SlidingWindowMLASpec)
        for group in swa_mla_specs
        for spec in group.kv_cache_specs.values()
    )
    # Split each SWA UniformKV group into smaller groups to align their #(layer tuples)
    # Possibly padding layer tuples for this.
    # Additionally, we also pad KV blocks in each SWA layer, to align the page size
    # with the corresponding layer in the full-MLA group.
    all_page_sizes = full_mla_spec.get_page_sizes()
    swa_mla_groups = []
    for sm_spec in swa_mla_specs:
        sm_page_sizes = sm_spec.get_page_sizes()
        layers_per_size: dict[int, list[str]] = defaultdict(list)
        # Guarantees the `min(...)` below always has at least one candidate.
        assert max(sm_page_sizes) <= max(all_page_sizes)
        # Unify page size by padding layers' page_size to the nearest larger page_size.
        # Compute candidate (nearest larger page_size) for each unique page size.
        size_to_candidate: dict[int, int] = {}
        for ps in sm_page_sizes:
            size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
        # Pad and collect layer names per page size.
        for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
            current_size = layer_spec.page_size_bytes
            candidate = size_to_candidate[current_size]
            if current_size < candidate:
                # In-place update that bypasses the class's own __setattr__
                # (presumably because the spec is a frozen dataclass —
                # TODO confirm).
                object.__setattr__(layer_spec, "page_size_padded", candidate)
            layers_per_size[candidate].append(layer_name)
        # NOTE(yifan): for now, inside a UniformKV group, each page_size should
        # have the same number of layers. This also means we don't need to pad layers
        # inside a partial-full layer tuple.
        assert len(set(len(layers) for layers in layers_per_size.values())) == 1
        num_layers_per_size = len(next(iter(layers_per_size.values())))
        # Split layers inside each UniformKV group for aligned #(layers).
        # See `_get_kv_cache_groups_uniform_page_size` for more details.
        num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
        # One tuple per position: the k-th layer of every page-size bucket.
        layer_tuples = list(zip(*layers_per_size.values()))
        for i in range(num_tuple_groups):
            # Strided slice: sub-group i takes tuples i, i + num_tuple_groups, ...
            group_layer_tuples = layer_tuples[i::num_tuple_groups]
            # Flatten tuples and build dict for from_specs
            group_layer_names = [
                name for layer_tuple in group_layer_tuples for name in layer_tuple
            ]
            group_layer_specs = {
                name: sm_spec.kv_cache_specs[name] for name in group_layer_names
            }
            sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
            assert sub_sm_spec is not None
            swa_mla_groups.append(
                KVCacheGroupSpec(
                    layer_names=group_layer_names,
                    kv_cache_spec=sub_sm_spec,
                )
            )
    return [full_mla_group, *swa_mla_groups]
def _annotate_eagle_groups_deepseek_v4(
    vllm_config: VllmConfig,
    kv_cache_spec: dict[str, KVCacheSpec],
    kv_cache_groups: list[KVCacheGroupSpec],
) -> None:
    """Flag the KV cache group that holds DeepseekV4's MTP layer.

    Does nothing unless speculative decoding reports ``use_eagle()`` and at
    least one layer spec has ``model_version == "deepseek_v4"``. Otherwise
    the first group whose ``layer_names`` contains the last key of
    ``kv_cache_spec`` gets ``is_eagle_group = True``; other groups are left
    untouched.
    """
    speculative = vllm_config.speculative_config
    eagle_enabled = speculative is not None and speculative.use_eagle()
    if not eagle_enabled:
        return

    # Detection uses the merged MLA spec's model_version.
    is_deepseek_v4 = False
    for layer_spec in kv_cache_spec.values():
        if getattr(layer_spec, "model_version", None) == "deepseek_v4":
            is_deepseek_v4 = True
            break
    if not is_deepseek_v4:
        return

    # DeepseekV4's MTP attention layer is always the last layer, and we flag
    # whichever group contains it.
    # FIXME(yifan): avoid/generalize this hacky check.
    *_, mtp_layer = kv_cache_spec
    owner = next(
        (group for group in kv_cache_groups if mtp_layer in group.layer_names),
        None,
    )
    if owner is not None:
        owner.is_eagle_group = True
def get_kv_cache_groups(
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec]
) -> list[KVCacheGroupSpec]:
......@@ -1268,6 +1634,14 @@ def get_kv_cache_groups(
# full attention, or all layers are sliding window attention with the
# same window size). Put all layers into one group.
return _get_kv_cache_groups_uniform_type(uniform_spec)
elif grouped_specs := group_and_unify_kv_cache_specs(kv_cache_spec):
# DeepseekV4 case: All layers need the same number of token slots,
# yet some layers are full attention while others are sliding window
# attention in different sizes. Need to group layers into multiple
# UniformTypeKVCacheSpecs.
kv_cache_groups = _get_kv_cache_groups_uniform_groups(grouped_specs)
_annotate_eagle_groups_deepseek_v4(vllm_config, kv_cache_spec, kv_cache_groups)
return kv_cache_groups
# As KVCacheManager can only allocate memory of one size, we need to unify
# the page size of the layers. For cases cannot be unified, this function
......@@ -1360,15 +1734,40 @@ def _max_memory_usage_bytes_from_groups(
if not kv_cache_groups:
return 0
# UniformTypeKVCacheSpecs special case (single group, per-layer specs)
if len(kv_cache_groups) == 1 and isinstance(
kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
):
# UniformTypeKVCacheSpecs special case (single group, per-layer specs)
per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs
return sum(
spec.max_memory_usage_bytes(vllm_config)
for spec in per_layer_specs.values()
)
elif all(
isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs)
for group in kv_cache_groups
):
# Special case (only DeepseekV4 for now): all groups are
# UniformTypeKVCacheSpecs.
# They must already be page_size aligned and share a common padded
# layer-tuple layout. Even groups with fewer actual tuples still reserve
# the global number of tuple slots in the shared tensor layout.
full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec)
layer_tuple_bytes = sum(full_mla_spec.get_page_sizes())
num_layer_tuples = max(
cast(UniformTypeKVCacheSpecs, group.kv_cache_spec).get_num_layer_tuples()
for group in kv_cache_groups
)
total_max_mem_usage_bytes = 0
for group in kv_cache_groups:
group_spec = cast(UniformTypeKVCacheSpecs, group.kv_cache_spec)
g_max_mem_usage_pages = group_spec.max_memory_usage_pages(vllm_config)
g_max_mem_usage_page_bytes = (
num_layer_tuples * g_max_mem_usage_pages * layer_tuple_bytes
)
total_max_mem_usage_bytes += g_max_mem_usage_page_bytes
return total_max_mem_usage_bytes
# General case: group_size pools, each shared by one layer per group
# Memory = group_size * page_size * blocks_for_max_len
......@@ -1515,7 +1914,13 @@ def _project_kv_cache_groups_to_worker(
for layer_name in worker_layer_names
},
)
projected_groups.append(KVCacheGroupSpec(worker_layer_names, group_spec))
projected_groups.append(
KVCacheGroupSpec(
worker_layer_names,
group_spec,
is_eagle_group=group.is_eagle_group and bool(worker_layer_names),
)
)
return projected_groups
......@@ -1698,10 +2103,7 @@ class BlockHashListWithBlockSize:
def _get_value_at(self, idx: int) -> BlockHash:
    """Return the merged hash for the ``idx``-th coarse block.

    The coarse block covers ``self.scale_factor`` consecutive entries of
    ``self.block_hashes`` starting at ``idx * self.scale_factor``; their
    bytes are concatenated in order to form the merged hash.
    """
    base = idx * self.scale_factor
    end = base + self.scale_factor
    # A single join replaces the repeated ``+=`` concatenation and the
    # duplicate, unreachable second ``return`` that followed it.
    # Note: slicing does not raise on an out-of-range ``base``/``end``.
    return BlockHash(b"".join(self.block_hashes[base:end]))
BlockHashList = list[BlockHash] | BlockHashListWithBlockSize
......@@ -41,6 +41,7 @@ class SchedulerInterface(ABC):
kv_cache_config: "KVCacheConfig",
structured_output_manager: "StructuredOutputManager",
block_size: int,
hash_block_size: int,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
......
......@@ -71,6 +71,7 @@ class Scheduler(SchedulerInterface):
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
block_size: int,
hash_block_size: int | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
......@@ -222,6 +223,8 @@ class Scheduler(SchedulerInterface):
self.num_lookahead_tokens = self.num_spec_tokens
# Create the KV cache manager.
if hash_block_size is None:
hash_block_size = block_size
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
......@@ -231,7 +234,7 @@ class Scheduler(SchedulerInterface):
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
hash_block_size=self.block_size,
hash_block_size=hash_block_size,
metrics_collector=self.kv_metrics_collector,
)
# Bind GPU block pool to the KV connector. This must happen after
......@@ -2018,7 +2021,7 @@ class Scheduler(SchedulerInterface):
# the connector.
self.kv_cache_manager.remove_skipped_blocks(
request_id=request.request_id,
total_computed_tokens=request.num_tokens,
total_computed_tokens=request.num_computed_tokens,
)
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
......
......@@ -20,6 +20,7 @@ from vllm.v1.kv_cache_interface import (
MambaSpec,
MLAAttentionSpec,
SinkFullAttentionSpec,
SlidingWindowMLASpec,
SlidingWindowSpec,
TQFullAttentionSpec,
)
......@@ -534,12 +535,10 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
):
# Skip prefix matching check if the block is not aligned with
# `alignment_tokens`.
if (
num_contiguous_blocks == 0
and block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
if num_contiguous_blocks == 0 and block_size != alignment_tokens:
post_pop_blocks = i if use_eagle else i + 1
if (post_pop_blocks * block_size) % alignment_tokens != 0:
continue
# Add the cached block to the computed blocks.
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
......@@ -1118,6 +1117,7 @@ spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
TQFullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
SlidingWindowMLASpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment