Unverified Commit f4b42df0 authored by Vibhav Agarwal's avatar Vibhav Agarwal Committed by GitHub
Browse files

[Attention Backend] TurboQuant: 2-bit KV cache compression with 4x capacity (#38479)


Signed-off-by: default avatarvibhavagarwal5 <vibhavagarwal5@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 3bfe55a0
......@@ -82,6 +82,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
TURBOQUANT = "vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
# set to None to avoid alias with other backend, whose value is an empty string
CUSTOM = None
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant attention backend for vLLM.
Prefill: Standard scaled dot-product attention on uncompressed K/V,
then quantize K and store K+V into combined cache slot.
Decode: Compute TQ attention scores from compressed cache,
unpack FP16 values, softmax + weighted sum.
Cache layout (no leading 2 dimension):
(num_blocks, block_size, num_kv_heads, slot_size)
where slot_size = key_packed_size + value_fp16_size
Per-head per-position slot layout:
[key_packed (kps bytes) | value_fp16 (D*2 bytes)]
For turboquant_k3v4_nc head_dim=256: [100 bytes key | 512 bytes value] = 612
"""
import functools
import math
from dataclasses import dataclass
from typing import Any, ClassVar
import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.triton_utils import triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.fa_utils import (
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.ops.triton_turboquant_decode import (
_tq_full_dequant_kv,
_use_fp8_e4b15,
triton_turboquant_decode_attention,
)
from vllm.v1.attention.ops.triton_turboquant_store import triton_turboquant_store
_HAS_FLASH_ATTN = is_flash_attn_varlen_func_available()
if _HAS_FLASH_ATTN:
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
# Continuation prefill: for small continuation chunks (q_len ≤ threshold),
# use the TQ decode kernel directly instead of full-dequant + flash_attn.
# do_kv_cache_update already stored all tokens to TQ cache, so the decode
# kernel can read them efficiently. This avoids O(cached_len) dequant work
# per continuation, eliminating the O(N²/chunk_size) collapse at long context.
_CONTINUATION_DECODE_THRESHOLD = 128
def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
"""Orthonormal Hadamard matrix (Sylvester construction), cached per (d, device).
Precomputed D×D matrix enables matmul-based WHT — single cuBLAS GEMM
instead of log2(D) butterfly kernel launches. 64KB for D=128.
"""
# Normalize device string so "cuda" and "cuda:0" hit the same cache entry.
return _build_hadamard_cached(d, str(torch.device(device_str)))
@functools.cache
def _build_hadamard_cached(d: int, device_str: str) -> torch.Tensor:
H = torch.tensor([[1.0]])
while H.shape[0] < d:
H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
return (H / math.sqrt(d)).to(torch.device(device_str))
class TurboQuantAttentionBackend(AttentionBackend):
"""Attention backend using TurboQuant KV-cache compression."""
accept_output_buffer: bool = True
forward_includes_kv_cache_update: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"turboquant_k8v4",
"turboquant_4bit_nc",
"turboquant_k3v4_nc",
"turboquant_3bit_nc",
]
@staticmethod
def get_name() -> str:
return "TURBOQUANT"
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [16, 32, 64, 128]
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
return attn_type == AttentionType.DECODER
@classmethod
def supports_per_head_quant_scales(cls) -> bool:
return False
@staticmethod
def get_impl_cls() -> type["TurboQuantAttentionImpl"]:
return TurboQuantAttentionImpl
@staticmethod
def get_builder_cls() -> type["TurboQuantMetadataBuilder"]:
return TurboQuantMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "turboquant_4bit_nc",
) -> tuple[int, ...]:
"""Combined K+V cache shape — no leading 2 dimension.
Standard attention backends use (2, num_blocks, block_size, num_kv_heads,
head_dim) with a leading 2 to separate K and V. TurboQuant packs K+V
into a single interleaved slot per head per position, so the cache is:
(num_blocks, block_size, num_kv_heads, slot_size_aligned)
Each slot = [key_packed | value_packed | padding].
This is safe because TQ has its own get_kv_cache_shape override and
never shares cache tensors with other backends. Layers that fall back
to native dtype via kv_cache_dtype_skip_layers get their own
standard-shaped cache allocation.
head_size is the model's real head_dim. slot_size_aligned is computed
from the TQ config to ensure correct cache allocation for all head dims.
"""
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype_str, head_size)
return (num_blocks, block_size, num_kv_heads, tq_config.slot_size_aligned)
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
if kv_cache_dtype is None:
return False
return kv_cache_dtype.startswith("turboquant_")
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
# head_size from spec is effective_head_size (padded_slot//2),
# not the model's actual head_dim. Accept any positive value.
return head_size > 0
@dataclass
class TurboQuantMetadata(AttentionMetadata):
"""Metadata for TurboQuant attention."""
seq_lens: torch.Tensor # (num_reqs,) — total context length per request
slot_mapping: torch.Tensor # (num_tokens,) — cache slot for each token
block_table: torch.Tensor # (num_reqs, max_num_blocks)
query_start_loc: torch.Tensor # (num_reqs + 1,) — cu_seqlens for queries
num_actual_tokens: int = 0 # actual tokens (excluding padding)
max_query_len: int = 0 # longest query in batch
max_seq_len: int = 0 # longest context in batch
is_prefill: bool = False
num_decodes: int = 0 # number of decode requests (first in batch)
num_decode_tokens: int = 0 # tokens from decode requests
class TurboQuantMetadataBuilder(AttentionMetadataBuilder[TurboQuantMetadata]):
"""Builds TurboQuantMetadata from scheduler output."""
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self._init_reorder_batch_threshold(1, supports_spec_as_decode=False)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TurboQuantMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# Set seq_lens to 1 so CUDA graph capture is fast
# (real seq_lens are filled at replay time).
attn_metadata.seq_lens.fill_(1)
return attn_metadata
def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
"""Build TurboQuantMetadata from common attention metadata."""
cam = common_attn_metadata
# With reorder_batch_threshold=1, the model runner guarantees
# decodes come first in the batch. split_decodes_and_prefills
# finds the boundary (operates on CPU tensors — no GPU sync).
assert self.reorder_batch_threshold is not None
num_decodes, num_prefills, num_decode_tokens, _ = split_decodes_and_prefills(
cam, decode_threshold=self.reorder_batch_threshold
)
return TurboQuantMetadata(
seq_lens=cam.seq_lens,
slot_mapping=cam.slot_mapping,
block_table=cam.block_table_tensor,
query_start_loc=cam.query_start_loc,
num_actual_tokens=cam.num_actual_tokens,
max_query_len=cam.max_query_len,
max_seq_len=cam.max_seq_len,
is_prefill=(cam.max_query_len > 1),
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
)
class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
"""TurboQuant attention implementation.
Vectorized PyTorch: batch quantize/store, vectorized bit-unpack
decode with einsum scores and value gather.
"""
supports_quant_query_input: bool = False
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
sliding_window: int | None = None,
kv_cache_dtype: str = "auto",
logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
**kwargs,
):
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = num_heads // self.num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
self.tq_config = TurboQuantConfig.from_cache_dtype(kv_cache_dtype, head_size)
# Pre-compute kernel constants from config (avoid repeated arithmetic)
cfg = self.tq_config
self._mse_bytes = (
math.ceil(head_size * cfg.key_mse_bits / 8)
if not cfg.key_fp8
else head_size
)
self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1
# Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
# and benchmarks show no regression vs dynamic in eager mode).
vllm_config = get_current_vllm_config()
self.max_num_kv_splits = (
vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
)
def _ensure_on_device(self, layer, device):
"""One-time derivation of TQ buffers (rotation matrices, midpoints).
Registered buffers (_tq_signs, _tq_centroids) are already on the
correct device via register_buffer + model.to(device).
"""
if not hasattr(layer, "_tq_cached"):
D = layer._tq_signs.shape[0]
signs = layer._tq_signs.to(device=device, dtype=torch.float32)
# WHT rotation: orthonormal + self-inverse, enabling future
# in-kernel butterfly fusion and trivial inverse for continuation.
H = _build_hadamard(D, str(device))
layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous()
layer._tq_Pi = layer._tq_PiT.T.contiguous()
c = layer._tq_centroids.to(device=device, dtype=torch.float32)
# Precompute midpoints for threshold-based quantization
c_sorted, _ = c.sort()
layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
# Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf)
# are pre-allocated via register_buffer in Attention.__init__
# and moved to GPU by model.to(device) — no allocation needed
# here. The memory profiler sees them before KV cache sizing.
layer._tq_cached = True
def do_kv_cache_update(
self,
layer: torch.nn.Module,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Store compressed K/V into the combined TQ cache.
Called as a separate custom op (unified_kv_cache_update) BEFORE
the attention forward, matching FlashAttention's split pattern.
slot_mapping is already sliced to num_actual_tokens by the caller.
"""
N = slot_mapping.shape[0]
if N <= 0:
return
device = key.device
self._ensure_on_device(layer, device)
k = key[:N].view(N, self.num_kv_heads, self.head_size)
v = value[:N].view(N, self.num_kv_heads, self.head_size)
self._store_kv(k, v, kv_cache, slot_mapping, layer)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: "TurboQuantMetadata",
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
num_tokens = query.shape[0]
if output is None:
output = torch.zeros(
num_tokens,
self.num_heads * self.head_size,
dtype=query.dtype,
device=query.device,
)
if attn_metadata is None:
return output.fill_(0)
# Slice to actual tokens
N = attn_metadata.num_actual_tokens
if N <= 0:
return output.fill_(0)
q = query[:N].view(N, self.num_heads, self.head_size)
# Get TQ buffers, ensure on device (one-time migration).
# Use Any-typed alias for dynamic _tq_* attrs set by _ensure_on_device.
tq_layer: Any = layer
device = q.device
self._ensure_on_device(tq_layer, device)
Pi = tq_layer._tq_Pi
PiT = tq_layer._tq_PiT
centroids = tq_layer._tq_centroids
# Compute attention (KV cache was already updated by do_kv_cache_update)
# With reorder_batch_threshold=1, decodes come first in the batch.
# num_decodes/num_decode_tokens from metadata give the split point.
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
if not attn_metadata.is_prefill:
# Pure decode batch — fast path
attn_out = self._decode_attention(
q, kv_cache, attn_metadata, Pi, centroids, PiT, layer
)
elif num_decodes == 0:
# Pure prefill batch
k = key[:N].view(N, self.num_kv_heads, self.head_size)
v = value[:N].view(N, self.num_kv_heads, self.head_size)
attn_out = self._prefill_attention(
q,
k,
v,
kv_cache,
attn_metadata,
Pi,
centroids,
PiT,
layer=layer,
)
else:
# Mixed batch: decodes first (guaranteed by reorder_batch).
attn_out = torch.zeros(
N, self.num_heads, self.head_size, device=device, dtype=q.dtype
)
# --- Decode portion (first num_decodes requests) ---
# Use full-batch max_seq_len as safe upper bound (no GPU sync).
decode_meta = TurboQuantMetadata(
seq_lens=attn_metadata.seq_lens[:num_decodes],
slot_mapping=attn_metadata.slot_mapping[:num_decode_tokens],
block_table=attn_metadata.block_table[:num_decodes],
query_start_loc=attn_metadata.query_start_loc[: num_decodes + 1],
num_actual_tokens=num_decode_tokens,
max_query_len=1,
max_seq_len=attn_metadata.max_seq_len,
is_prefill=False,
)
attn_out[:num_decode_tokens] = self._decode_attention(
q[:num_decode_tokens], kv_cache, decode_meta, Pi, centroids, PiT, layer
)
# --- Prefill portion (remaining requests) ---
# CRITICAL: use prefill-specific max_seq_len so flash_attn's
# fast path (max_query_len == max_seq_len) triggers for
# first-chunk prefills. Using full-batch max_seq_len breaks
# this because decode requests inflate max_seq_len.
prefill_seq_lens = attn_metadata.seq_lens[num_decodes:]
# Use CPU-side max to avoid GPU→CPU sync from .item()
prefill_max_seq = max(attn_metadata.seq_lens[num_decodes:].tolist())
prefill_qsl = (
attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens
)
prefill_meta = TurboQuantMetadata(
seq_lens=prefill_seq_lens,
slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N],
block_table=attn_metadata.block_table[num_decodes:],
query_start_loc=prefill_qsl,
num_actual_tokens=N - num_decode_tokens,
max_query_len=attn_metadata.max_query_len,
max_seq_len=prefill_max_seq,
is_prefill=True,
)
k = key[:N].view(N, self.num_kv_heads, self.head_size)
v = value[:N].view(N, self.num_kv_heads, self.head_size)
attn_out[num_decode_tokens:] = self._prefill_attention(
q[num_decode_tokens:],
k[num_decode_tokens:],
v[num_decode_tokens:],
kv_cache,
prefill_meta,
Pi,
centroids,
PiT,
layer=layer,
)
# Write into output buffer: attn_out is (N, Hq, D)
# output may be 2D (N, Hq*D) or 3D (N, Hq, D)
if output.ndim == 3:
output[:N] = attn_out.to(output.dtype)
else:
output[:N] = attn_out.reshape(N, -1).to(output.dtype)
return output
# ------------------------------------------------------------------ #
# Store K/V into combined cache (vectorized) #
# ------------------------------------------------------------------ #
def _store_kv(
self,
key: torch.Tensor, # (N, Hk, D)
value: torch.Tensor, # (N, Hk, D)
kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
slot_mapping: torch.Tensor,
layer: Any,
):
"""Quantize + store via fused Triton kernel."""
triton_turboquant_store(
key,
value,
kv_cache,
slot_mapping,
layer._tq_PiT,
layer._tq_midpoints,
mse_bits=self.tq_config.key_mse_bits,
key_packed_size=self.tq_config.key_packed_size,
value_quant_bits=self.tq_config.effective_value_quant_bits,
key_fp8=self.tq_config.key_fp8,
)
# ------------------------------------------------------------------ #
# Prefill: SDPA on raw Q/K/V with causal mask #
# ------------------------------------------------------------------ #
def _prefill_attention(
self,
query: torch.Tensor, # (N, Hq, D)
key: torch.Tensor, # (N, Hk, D)
value: torch.Tensor, # (N, Hk, D)
kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
attn_metadata: TurboQuantMetadata,
Pi: torch.Tensor,
centroids: torch.Tensor,
PiT: torch.Tensor | None = None,
layer: Any = None,
) -> torch.Tensor:
N, Hq, D = query.shape
# Fast path: use flash_attn for first-chunk prefills (all K/V in batch).
# max_query_len == max_seq_len means no request has prior cached KV.
# Both are Python ints — no GPU sync.
if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype)
flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=attn_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=self.scale,
causal=True,
out=output,
)
return output
# Continuation or no flash_attn: per-request attention.
# For continuation chunks (seq_len > q_len), we must attend to
# previously cached K/V from the TQ cache, not just the current
# chunk's raw K/V.
Hk = key.shape[1]
use_gqa = Hk < Hq
query_start_loc = attn_metadata.query_start_loc
num_reqs = query_start_loc.shape[0] - 1
output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype)
# Convert to Python lists once (single CPU-GPU sync) instead of
# per-request .item() calls that each force a sync.
qsl = query_start_loc.tolist()
seq_lens_list = attn_metadata.seq_lens.tolist()
# Pre-allocate cu_seqlens for single-request flash_attn calls
# to avoid per-request host→device tensor creation.
_cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32)
for i in range(num_reqs):
q_start = qsl[i]
q_end = qsl[i + 1]
q_len = q_end - q_start
if q_len <= 0:
continue
seq_len = seq_lens_list[i]
q_seq = query[q_start:q_end] # (q_len, Hq, D)
k_seq = key[q_start:q_end] # (q_len, Hk, D)
v_seq = value[q_start:q_end] # (q_len, Hk, D)
if q_len == seq_len:
# First-chunk prefill: all K/V are in the current batch.
if _HAS_FLASH_ATTN:
out = torch.empty_like(q_seq)
_cu_2[1] = q_len
cu = _cu_2
flash_attn_varlen_func(
q=q_seq,
k=k_seq,
v=v_seq,
cu_seqlens_q=cu,
cu_seqlens_k=cu,
max_seqlen_q=q_len,
max_seqlen_k=q_len,
softmax_scale=self.scale,
causal=True,
out=out,
)
else:
q_t = q_seq.transpose(0, 1).contiguous()
k_t = k_seq.transpose(0, 1).contiguous()
v_t = v_seq.transpose(0, 1).contiguous()
out = F.scaled_dot_product_attention(
q_t,
k_t,
v_t,
is_causal=True,
scale=self.scale,
enable_gqa=use_gqa,
).transpose(0, 1)
output[q_start:q_end] = out.to(query.dtype)
else:
# Continuation chunk: tokens already stored to TQ cache
# by do_kv_cache_update. Use decode kernel directly to
# avoid O(cached_len) full-dequant per continuation.
# For large continuations, fall back to _continuation_prefill.
cached_len = seq_len - q_len
if q_len <= _CONTINUATION_DECODE_THRESHOLD:
# Fast path: treat each query as a decode request
# with incremental seq_lens for causal masking.
synth_seq_lens = torch.arange(
cached_len + 1,
seq_len + 1,
device=query.device,
dtype=attn_metadata.seq_lens.dtype,
)
synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1)
out = triton_turboquant_decode_attention(
query=q_seq,
kv_cache=kv_cache,
block_table=synth_bt,
seq_lens=synth_seq_lens,
Pi=Pi,
centroids=centroids,
scale=self.scale,
mse_bits=self.tq_config.key_mse_bits,
key_packed_size=self.tq_config.key_packed_size,
value_quant_bits=(self.tq_config.effective_value_quant_bits),
key_fp8=self.tq_config.key_fp8,
norm_correction=self.tq_config.norm_correction,
PiT=PiT,
)
else:
# Large continuation: dequant cached K/V and use
# flash_attn for better throughput.
out = self._continuation_prefill(
layer,
q_seq,
k_seq,
v_seq,
kv_cache,
attn_metadata.block_table[i : i + 1],
cached_len,
seq_len,
Pi,
centroids,
)
output[q_start:q_end] = out.to(query.dtype)
return output
def _continuation_prefill(
self,
layer: Any,
query: torch.Tensor, # (q_len, Hq, D)
key_chunk: torch.Tensor, # (q_len, Hk, D)
val_chunk: torch.Tensor, # (q_len, Hk, D)
kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
block_table: torch.Tensor, # (1, max_num_blocks)
cached_len: int,
seq_len: int,
Pi: torch.Tensor,
centroids: torch.Tensor,
) -> torch.Tensor:
"""Handle continuation chunk by dequanting cached K/V from TQ cache.
Dequants previously cached K/V, concatenates with the current
chunk's raw K/V, then runs flash_attn with causal masking.
"""
q_len, Hq, D = query.shape
Hk = key_chunk.shape[1]
device = query.device
block_size = kv_cache.shape[1]
BLOCK_D = triton.next_power_of_2(D)
mse_bytes = self._mse_bytes
val_data_bytes = self._val_data_bytes
# Dequant cached K/V from TQ cache
# Allocate slightly over to align to block_size for the grid.
# Reuse cached buffers to avoid per-call allocation (~16MB at 8K).
alloc_len = math.ceil(cached_len / block_size) * block_size
buf_shape = (1, Hk, alloc_len, D)
k_buf = getattr(layer, "_tq_k_dequant_buf", None)
if k_buf is None or k_buf.shape[2] < alloc_len:
k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
layer._tq_k_dequant_buf = k_buf
layer._tq_v_dequant_buf = v_buf
else:
v_buf = layer._tq_v_dequant_buf
k_cached = k_buf[:, :, :alloc_len, :].zero_()
v_cached = v_buf[:, :, :alloc_len, :].zero_()
grid = (alloc_len, 1 * Hk)
_tq_full_dequant_kv[grid](
kv_cache,
block_table,
centroids,
k_cached,
v_cached,
k_cached.stride(0),
k_cached.stride(1),
k_cached.stride(2),
v_cached.stride(0),
v_cached.stride(1),
v_cached.stride(2),
kv_cache.stride(0),
kv_cache.stride(1),
kv_cache.stride(2),
block_table.stride(0),
HEAD_DIM=D,
BLOCK_SIZE=block_size,
NUM_KV_HEADS=Hk,
MSE_BYTES=mse_bytes,
KPS=self.tq_config.key_packed_size,
VQB=self.tq_config.effective_value_quant_bits,
VAL_DATA_BYTES=val_data_bytes,
MSE_BITS=self.tq_config.key_mse_bits,
KEY_FP8=1 if self.tq_config.key_fp8 else 0,
BLOCK_D=BLOCK_D,
NORM_CORRECTION=1 if self.tq_config.norm_correction else 0,
FP8_E4B15=_use_fp8_e4b15(device.index or 0),
num_warps=4,
)
# Inverse-rotate MSE keys back to original space
if not self.tq_config.key_fp8:
k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float()
k_flat = k_flat @ Pi
k_cached_trim = (
k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1)
) # (cached_len, Hk, D)
else:
k_cached_trim = (
k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
) # (cached_len, Hk, D)
v_cached_trim = (
v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
) # (cached_len, Hk, D)
# Concatenate cached + current chunk K/V (match query dtype)
qdtype = query.dtype
k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0)
v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0)
# Attention: q_len queries attending to seq_len K/V with causal mask
if _HAS_FLASH_ATTN:
output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype)
cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
flash_attn_varlen_func(
q=query,
k=k_full,
v=v_full,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=seq_len,
softmax_scale=self.scale,
causal=True,
out=output,
)
return output
else:
# SDPA fallback: expand KV for GQA, build causal mask
q_t = query.transpose(0, 1).unsqueeze(0) # (1, Hq, q_len, D)
k_t = k_full.transpose(0, 1).unsqueeze(0) # (1, Hk, seq_len, D)
v_t = v_full.transpose(0, 1).unsqueeze(0) # (1, Hk, seq_len, D)
# Build causal mask: query position p can attend to K position j
# where j <= cached_len + p (p is 0-indexed within chunk)
q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
mask = k_pos <= q_pos # (q_len, seq_len)
out = F.scaled_dot_product_attention(
q_t,
k_t,
v_t,
attn_mask=mask,
scale=self.scale,
enable_gqa=(Hk < Hq),
) # (1, Hq, q_len, D)
return out[0].transpose(0, 1) # (q_len, Hq, D)
# ------------------------------------------------------------------ #
# Decode: Triton TQ decode attention #
# ------------------------------------------------------------------ #
def _decode_attention(
self,
query: torch.Tensor, # (B, Hq, D)
kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
attn_metadata: TurboQuantMetadata,
Pi: torch.Tensor,
centroids: torch.Tensor,
PiT: torch.Tensor | None = None,
layer: torch.nn.Module | None = None,
) -> torch.Tensor:
# Grab cached decode buffers from the layer (lazily allocated).
mid_o_buf = output_buf = lse_buf = None
if layer is not None:
mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
output_buf = getattr(layer, "_tq_output_buf", None)
lse_buf = getattr(layer, "_tq_lse_buf", None)
result = triton_turboquant_decode_attention(
query=query,
kv_cache=kv_cache,
block_table=attn_metadata.block_table,
seq_lens=attn_metadata.seq_lens,
Pi=Pi,
centroids=centroids,
scale=self.scale,
mse_bits=self.tq_config.key_mse_bits,
key_packed_size=self.tq_config.key_packed_size,
value_quant_bits=self.tq_config.effective_value_quant_bits,
key_fp8=self.tq_config.key_fp8,
norm_correction=self.tq_config.norm_correction,
PiT=PiT,
mid_o_buf=mid_o_buf,
output_buf=output_buf,
lse_buf=lse_buf,
buf_holder=layer,
max_num_kv_splits=self.max_num_kv_splits,
)
return result
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Triton fused TurboQuant decode attention.
Decode path: Triton stage1 (split-KV tiled attention scoring + value
accumulation) + stage2 (log-sum-exp reduction across splits).
Supports FP8 (E4M3) keys, 3-bit and 4-bit uniform quantized values.
"""
import math
from typing import Any
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.attention.ops.triton_decode_attention import (
_fwd_kernel_stage2,
)
_FP8_E4B15: dict[int, int] = {}
def _use_fp8_e4b15(device: int = 0) -> int:
"""Return 1 if device needs fp8e4b15 (Ampere/Ada, SM < 8.9), else 0."""
if device not in _FP8_E4B15:
cap = torch.cuda.get_device_capability(device)
_FP8_E4B15[device] = 1 if cap < (8, 9) else 0
return _FP8_E4B15[device]
# ---------------------------------------------------------------------------
# Stage 1: Fused TQ score + value accumulation (BLOCK_KV tiled)
# ---------------------------------------------------------------------------
@triton.jit
def _tq_decode_stage1(
# Precomputed query projection
Q_rot_ptr, # [B, Hq, D] float32
# Compressed KV cache (combined K+V)
KV_cache_ptr, # [num_blocks, block_size, Hk, padded_slot] uint8
# Block table and sequence info
Block_table_ptr, # [B, max_num_blocks] int32
Seq_lens_ptr, # [B] int32
# TQ parameters
Centroids_ptr, # [n_centroids] float32
# Output (intermediate for stage2)
Mid_o_ptr, # [B, Hq, NUM_KV_SPLITS, D+1] float32
# Strides
stride_qb,
stride_qh, # Q strides: [B, Hq, D]
stride_cache_block,
stride_cache_pos,
stride_cache_head, # KV cache
stride_bt_b, # block_table stride per batch
stride_mid_b,
stride_mid_h,
stride_mid_s, # mid_o strides
# Constexpr dims
NUM_KV_HEADS: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_SIZE: tl.constexpr, # KV cache block_size (pages)
NUM_KV_SPLITS: tl.constexpr,
KV_GROUP_SIZE: tl.constexpr, # Hq // Hk
# TQ layout constants
MSE_BITS: tl.constexpr, # 3 or 4
MSE_BYTES: tl.constexpr, # ceil(D * mse_bits / 8)
KPS: tl.constexpr, # key_packed_size
VQB: tl.constexpr, # value_quant_bits (4 or 8=FP8)
VAL_DATA_BYTES: tl.constexpr, # ceil(D * vqb / 8) or D for FP8
# Score constants
ATTN_SCALE: tl.constexpr, # 1/sqrt(D)
# Block tile sizes
BLOCK_D: tl.constexpr, # next_power_of_2(HEAD_DIM)
BLOCK_KV: tl.constexpr, # tokens per tile (16)
KEY_FP8: tl.constexpr, # 1 if K is stored as FP8
NORM_CORRECTION: tl.constexpr = 0, # 1 = re-normalize centroids
FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
bid = tl.program_id(0) # batch index
hid = tl.program_id(1) # q_head index
sid = tl.program_id(2) # kv_split index
kv_head = hid // KV_GROUP_SIZE
# Sequence length for this batch
seq_len = tl.load(Seq_lens_ptr + bid)
# KV split range
split_len = tl.cdiv(seq_len, NUM_KV_SPLITS)
split_start = split_len * sid
split_end = tl.minimum(split_start + split_len, seq_len)
if split_start >= split_end:
return
# Dimension offsets
d_offs = tl.arange(0, BLOCK_D)
d_mask = d_offs < HEAD_DIM
kv_range = tl.arange(0, BLOCK_KV)
# Load query vector: q_rot — [BLOCK_D] float32
q_base = bid * stride_qb + hid * stride_qh
q_rot = tl.load(Q_rot_ptr + q_base + d_offs, mask=d_mask, other=0.0).to(tl.float32)
# Precompute byte/bit index vectors for MSE gather loads
if not KEY_FP8:
mse_bit_off = d_offs * MSE_BITS
mse_byte_idx = mse_bit_off // 8
mse_bit_shift = mse_bit_off % 8
mse_mask = (1 << MSE_BITS) - 1
# Precompute value bit/byte index vectors (loop-invariant)
if VQB == 3:
val_bit_off = d_offs * 3
val_byte_idx = val_bit_off // 8
val_bit_shift = val_bit_off % 8
# Online softmax accumulators
m_prev = -float("inf")
l_prev = 0.0
acc = tl.zeros([BLOCK_D], dtype=tl.float32)
bt_base = bid * stride_bt_b
# ================================================================
# TILED LOOP: process BLOCK_KV tokens per iteration
# ================================================================
for start_n in range(split_start, split_end, BLOCK_KV):
kv_offs = start_n + kv_range
kv_mask = kv_offs < split_end
page_idx = kv_offs // BLOCK_SIZE
page_off = kv_offs % BLOCK_SIZE
block_nums = tl.load(
Block_table_ptr + bt_base + page_idx,
mask=kv_mask,
other=0,
)
slot_bases = (
block_nums * stride_cache_block
+ page_off * stride_cache_pos
+ kv_head * stride_cache_head
)
# ============================================================
# COMPUTE ATTENTION SCORES: [BLOCK_KV]
# ============================================================
if KEY_FP8:
k_addrs = slot_bases[:, None] + d_offs[None, :]
k_raw = tl.load(
KV_cache_ptr + k_addrs,
mask=kv_mask[:, None] & d_mask[None, :],
other=0,
)
if FP8_E4B15:
k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32)
else:
k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32)
scores = (
tl.sum(
tl.where(d_mask[None, :], q_rot[None, :] * k_float, 0.0),
axis=1,
)
* ATTN_SCALE
)
scores = tl.where(kv_mask, scores, -float("inf"))
else:
# MSE unpack + norms
mse_addrs0 = slot_bases[:, None] + mse_byte_idx[None, :]
mse_raw0 = tl.load(
KV_cache_ptr + mse_addrs0,
mask=kv_mask[:, None] & d_mask[None, :],
other=0,
).to(tl.int32)
mse_raw1 = tl.load(
KV_cache_ptr + mse_addrs0 + 1,
mask=kv_mask[:, None] & d_mask[None, :],
other=0,
).to(tl.int32)
raw16 = mse_raw0 | (mse_raw1 << 8)
mse_idx = (raw16 >> mse_bit_shift[None, :]) & mse_mask
# Centroid gather + dot product
c_vals = tl.load(
Centroids_ptr + mse_idx,
mask=kv_mask[:, None] & d_mask[None, :],
other=0.0,
)
# Norm correction: re-normalize centroid vector to unit norm
if NORM_CORRECTION:
c_norm_sq = tl.sum(
tl.where(d_mask[None, :], c_vals * c_vals, 0.0),
axis=1,
)
c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16)
c_vals = c_vals * c_inv_norm[:, None]
term1 = tl.sum(
tl.where(d_mask[None, :], q_rot[None, :] * c_vals, 0.0),
axis=1,
)
# Load norms (fp16 -> fp32): norms are at MSE_BYTES offset
norm_bases = slot_bases + MSE_BYTES
n_lo = tl.load(KV_cache_ptr + norm_bases, mask=kv_mask, other=0).to(
tl.uint16
)
n_hi = tl.load(KV_cache_ptr + norm_bases + 1, mask=kv_mask, other=0).to(
tl.uint16
)
vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
scores = vec_norms * term1 * ATTN_SCALE
scores = tl.where(kv_mask, scores, -float("inf"))
# ============================================================
# ONLINE SOFTMAX UPDATE (block-level)
# ============================================================
n_e_max = tl.maximum(tl.max(scores, 0), m_prev)
re_scale = tl.exp(m_prev - n_e_max)
p = tl.exp(scores - n_e_max)
# ============================================================
# VALUE LOAD + DEQUANTIZE: [BLOCK_KV, BLOCK_D]
# ============================================================
val_bases = slot_bases + KPS
if VQB == 3:
val_addrs0 = val_bases[:, None] + val_byte_idx[None, :]
val_raw0 = tl.load(
KV_cache_ptr + val_addrs0,
mask=kv_mask[:, None] & d_mask[None, :],
other=0,
).to(tl.int32)
val_raw1 = tl.load(
KV_cache_ptr + val_addrs0 + 1,
mask=kv_mask[:, None] & d_mask[None, :],
other=0,
).to(tl.int32)
raw16 = val_raw0 | (val_raw1 << 8)
v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32)
sc_bases = val_bases + VAL_DATA_BYTES
sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(
tl.uint16
)
sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to(
tl.uint16
)
v_scales = (
(sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
)
zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to(
tl.uint16
)
zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to(
tl.uint16
)
v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
values = v_idx * v_scales[:, None] + v_zeros[:, None]
else: # VQB == 4
vb_idx = d_offs // 2
vb_shift = (d_offs % 2) * 4
val_addrs = val_bases[:, None] + vb_idx[None, :]
val_raw = tl.load(
KV_cache_ptr + val_addrs,
mask=kv_mask[:, None] & d_mask[None, :],
other=0,
).to(tl.int32)
v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32)
sc_bases = val_bases + VAL_DATA_BYTES
sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(
tl.uint16
)
sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to(
tl.uint16
)
v_scales = (
(sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
)
zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to(
tl.uint16
)
zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to(
tl.uint16
)
v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
values = v_idx * v_scales[:, None] + v_zeros[:, None]
# ============================================================
# WEIGHTED VALUE ACCUMULATION
# ============================================================
acc = acc * re_scale + tl.sum(p[:, None] * values, 0)
l_prev = l_prev * re_scale + tl.sum(p, 0)
m_prev = n_e_max
# Store partial result
out_base = bid * stride_mid_b + hid * stride_mid_h + sid * stride_mid_s
safe_l = tl.where(l_prev > 0.0, l_prev, 1.0)
tl.store(Mid_o_ptr + out_base + d_offs, acc / safe_l, mask=d_mask)
lse = m_prev + tl.log(safe_l)
tl.store(Mid_o_ptr + out_base + HEAD_DIM, lse)
# ---------------------------------------------------------------------------
# Pre-dequant kernel: Bulk dequant K (MSE+norms) and V to fp16
# ---------------------------------------------------------------------------
@triton.jit
def _tq_full_dequant_kv(
KV_cache_ptr,
Block_table_ptr,
Centroids_ptr,
K_out_ptr, # [B, Hk, max_seq, D] float16
V_out_ptr, # [B, Hk, max_seq, D] float16
stride_ko_b,
stride_ko_h,
stride_ko_s,
stride_vo_b,
stride_vo_h,
stride_vo_s,
stride_cache_block,
stride_cache_pos,
stride_cache_head,
stride_bt_b,
HEAD_DIM: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
MSE_BYTES: tl.constexpr,
KPS: tl.constexpr,
VQB: tl.constexpr,
VAL_DATA_BYTES: tl.constexpr,
MSE_BITS: tl.constexpr,
KEY_FP8: tl.constexpr,
BLOCK_D: tl.constexpr,
NORM_CORRECTION: tl.constexpr = 0,
FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
"""Full dequant: reconstruct K (MSE centroids * norm or FP8) and V to fp16."""
pos = tl.program_id(0)
bh = tl.program_id(1)
bid = bh // NUM_KV_HEADS
hid = bh % NUM_KV_HEADS
page_idx = pos // BLOCK_SIZE
page_off = pos % BLOCK_SIZE
block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx)
slot_base = (
block_num * stride_cache_block
+ page_off * stride_cache_pos
+ hid * stride_cache_head
)
d_offs = tl.arange(0, BLOCK_D)
d_mask = d_offs < HEAD_DIM
# === K dequant ===
ko_base = bid * stride_ko_b + hid * stride_ko_h + pos * stride_ko_s
if KEY_FP8:
k_raw = tl.load(KV_cache_ptr + slot_base + d_offs, mask=d_mask, other=0)
if FP8_E4B15:
k_recon = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32)
else:
k_recon = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32)
tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask)
else:
# MSE unpack (3-bit or 4-bit) + norms
mse_bit_off = d_offs * MSE_BITS
mse_byte_idx = mse_bit_off // 8
mse_bit_shift = mse_bit_off % 8
mse_umask = (1 << MSE_BITS) - 1
mse_raw0 = tl.load(
KV_cache_ptr + slot_base + mse_byte_idx, mask=d_mask, other=0
).to(tl.int32)
mse_raw1 = tl.load(
KV_cache_ptr + slot_base + mse_byte_idx + 1, mask=d_mask, other=0
).to(tl.int32)
raw16_key = mse_raw0 | (mse_raw1 << 8)
mse_idx = (raw16_key >> mse_bit_shift) & mse_umask
k_mse = tl.load(Centroids_ptr + mse_idx, mask=d_mask, other=0.0)
# Norm correction: re-normalize centroid vector to unit norm
if NORM_CORRECTION:
c_norm_sq = tl.sum(tl.where(d_mask, k_mse * k_mse, 0.0), axis=0)
c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16)
k_mse = k_mse * c_inv_norm
# Norms at MSE_BYTES offset (no QJL bytes)
norm_base = slot_base + MSE_BYTES
n_lo = tl.load(KV_cache_ptr + norm_base).to(tl.uint16)
n_hi = tl.load(KV_cache_ptr + norm_base + 1).to(tl.uint16)
vec_norm = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
k_recon = vec_norm * k_mse
tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask)
# === V dequant ===
val_base = slot_base + KPS
if VQB == 4:
vb_idx = d_offs // 2
vb_shift = (d_offs % 2) * 4
val_raw = tl.load(KV_cache_ptr + val_base + vb_idx, mask=d_mask, other=0).to(
tl.int32
)
v_idx = ((val_raw >> vb_shift) & 0xF).to(tl.float32)
sc_base = val_base + VAL_DATA_BYTES
sc_lo = tl.load(KV_cache_ptr + sc_base).to(tl.uint16)
sc_hi = tl.load(KV_cache_ptr + sc_base + 1).to(tl.uint16)
v_scale = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
zr_lo = tl.load(KV_cache_ptr + sc_base + 2).to(tl.uint16)
zr_hi = tl.load(KV_cache_ptr + sc_base + 3).to(tl.uint16)
v_zero = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
v_vals = v_idx * v_scale + v_zero
elif VQB == 3:
# 3-bit value unpack: 8 values per 3 bytes
val_bit_off = d_offs * 3
val_byte_idx = val_bit_off // 8
val_bit_shift = val_bit_off % 8
val_raw0 = tl.load(
KV_cache_ptr + val_base + val_byte_idx, mask=d_mask, other=0
).to(tl.int32)
val_raw1 = tl.load(
KV_cache_ptr + val_base + val_byte_idx + 1, mask=d_mask, other=0
).to(tl.int32)
raw16_val = val_raw0 | (val_raw1 << 8)
v_idx = ((raw16_val >> val_bit_shift) & 0x7).to(tl.float32)
sc_base = val_base + VAL_DATA_BYTES
sc_lo = tl.load(KV_cache_ptr + sc_base).to(tl.uint16)
sc_hi = tl.load(KV_cache_ptr + sc_base + 1).to(tl.uint16)
v_scale = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
zr_lo = tl.load(KV_cache_ptr + sc_base + 2).to(tl.uint16)
zr_hi = tl.load(KV_cache_ptr + sc_base + 3).to(tl.uint16)
v_zero = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
v_vals = v_idx * v_scale + v_zero
else:
v_vals = tl.zeros([BLOCK_D], dtype=tl.float32)
vo_base = bid * stride_vo_b + hid * stride_vo_h + pos * stride_vo_s
tl.store(V_out_ptr + vo_base + d_offs, v_vals.to(tl.float16), mask=d_mask)
# ---------------------------------------------------------------------------
# Stage 2: Reuse from triton_decode_attention.py
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Launcher — cached constants + fused GEMM
# ---------------------------------------------------------------------------
_layout_cache: dict = {}
def _get_layout(D, mse_bits, value_quant_bits, key_packed_size):
"""Get cached layout constants."""
key = (D, mse_bits, value_quant_bits, key_packed_size)
cfg = _layout_cache.get(key)
if cfg is None:
val_data_bytes = math.ceil(D * value_quant_bits / 8)
cfg = {
"mse_bytes": math.ceil(D * mse_bits / 8),
"val_data_bytes": val_data_bytes,
"mse_bits": mse_bits,
"n_centroids": 2**mse_bits,
"BLOCK_D": triton.next_power_of_2(D),
}
_layout_cache[key] = cfg
return cfg
def triton_turboquant_decode_attention(
query: torch.Tensor, # [B, Hq, D] — original query
kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8
block_table: torch.Tensor, # [B, max_num_blocks] int32
seq_lens: torch.Tensor, # [B] int32
Pi: torch.Tensor, # [D, D] float32
centroids: torch.Tensor, # [n_centroids] float32
scale: float,
mse_bits: int,
key_packed_size: int,
value_quant_bits: int,
key_fp8: bool = False,
norm_correction: bool = False,
PiT: torch.Tensor | None = None, # [D, D] pre-computed Pi.T contiguous
# Pre-allocated buffers (optional, avoids per-call allocation)
mid_o_buf: torch.Tensor | None = None,
output_buf: torch.Tensor | None = None,
lse_buf: torch.Tensor | None = None,
buf_holder: Any = None,
max_num_kv_splits: int = 32, # fixed split count (must be constant for cudagraph)
) -> torch.Tensor:
"""Launch fused TQ decode attention (Triton stage1 + stage2).
Returns: output tensor [B, Hq, D] in query's dtype.
"""
B, Hq, D = query.shape
Hk = kv_cache.shape[2]
block_size = kv_cache.shape[1]
kv_group_size = Hq // Hk
device = query.device
cfg = _get_layout(D, mse_bits, value_quant_bits, key_packed_size)
# Compute q_rot = q @ Pi.T (rotated query for MSE key scoring)
# FP8 path: pass query directly (float16); kernel casts inline.
# MSE path: still needs external GEMM (cuBLAS), so q_rot is float32.
if key_fp8:
q_rot = query.contiguous()
else:
q_float = query.float()
if PiT is None:
PiT = Pi.T.contiguous()
q_rot = (q_float @ PiT).contiguous()
NUM_KV_SPLITS = max_num_kv_splits
if (
mid_o_buf is not None
and mid_o_buf.shape[0] >= B
and mid_o_buf.shape[2] >= NUM_KV_SPLITS
):
mid_o = mid_o_buf[:B, :Hq, :NUM_KV_SPLITS, :]
else:
mid_o = torch.empty(
B,
Hq,
NUM_KV_SPLITS,
D + 1,
dtype=torch.float32,
device=device,
)
if buf_holder is not None:
buf_holder._tq_mid_o_buf = mid_o
# Stage 1: split-KV tiled attention scoring + value accumulation
fp8_e4b15 = _use_fp8_e4b15(device.index or 0)
BLOCK_KV = 4
grid = (B, Hq, NUM_KV_SPLITS)
_tq_decode_stage1[grid](
q_rot,
kv_cache,
block_table,
seq_lens,
centroids,
mid_o,
q_rot.stride(0),
q_rot.stride(1),
kv_cache.stride(0),
kv_cache.stride(1),
kv_cache.stride(2),
block_table.stride(0),
mid_o.stride(0),
mid_o.stride(1),
mid_o.stride(2),
NUM_KV_HEADS=Hk,
HEAD_DIM=D,
BLOCK_SIZE=block_size,
NUM_KV_SPLITS=NUM_KV_SPLITS,
KV_GROUP_SIZE=kv_group_size,
MSE_BITS=mse_bits,
MSE_BYTES=cfg["mse_bytes"],
KPS=key_packed_size,
VQB=value_quant_bits,
VAL_DATA_BYTES=cfg["val_data_bytes"],
ATTN_SCALE=scale,
BLOCK_D=cfg["BLOCK_D"],
BLOCK_KV=BLOCK_KV,
KEY_FP8=1 if key_fp8 else 0,
NORM_CORRECTION=1 if norm_correction else 0,
FP8_E4B15=fp8_e4b15,
num_warps=1,
num_stages=1,
)
# Stage 2: Reduce across KV splits
if output_buf is not None and output_buf.shape[0] >= B:
output = output_buf[:B, :Hq, :D]
else:
output = torch.empty(B, Hq, D, dtype=torch.float32, device=device)
if buf_holder is not None:
buf_holder._tq_output_buf = output
if lse_buf is not None and lse_buf.shape[0] >= B:
lse = lse_buf[:B, :Hq]
else:
lse = torch.empty(B, Hq, dtype=torch.float32, device=device)
if buf_holder is not None:
buf_holder._tq_lse_buf = lse
grid2 = (B, Hq)
_fwd_kernel_stage2[grid2](
mid_o,
output,
lse,
seq_lens,
mid_o.stride(0),
mid_o.stride(1),
mid_o.stride(2),
output.stride(0),
output.stride(1),
lse.stride(0),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=cfg["BLOCK_D"],
Lv=D,
num_warps=4,
num_stages=2,
)
return output.to(query.dtype)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused Triton kernels for TurboQuant KV store.
Two kernels:
1. _tq_fused_store_fp8: FP8 key scatter + value uniform quantization.
2. _tq_fused_store_mse: Fused binary-search bucketize + MSE index
packing + value quantization.
The launcher `triton_turboquant_store` selects the appropriate kernel.
"""
import math
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.attention.ops.triton_turboquant_decode import _use_fp8_e4b15
# ═══════════════════════════════════════════════════════════════════════
# Shared: value uniform quantization + pack + scale/zero store
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _store_quantized_value(
Value_ptr,
KV_cache_ptr,
base, # pid * D offset into Value_ptr
slot_base, # byte offset into KV_cache_ptr for this slot+head
d_offs, # tl.arange(0, BLOCK_D)
d_mask, # d_offs < D
D: tl.constexpr,
KPS: tl.constexpr,
VQB: tl.constexpr,
VAL_DATA_BYTES: tl.constexpr,
BLOCK_D: tl.constexpr,
BLOCK_VAL: tl.constexpr,
BLOCK_GRP: tl.constexpr,
):
"""Uniform quantization of values to VQB bits, pack, and store with scale/zero."""
val_cache_offset = KPS
if VQB == 3:
val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
tl.float32
)
val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
v_scale = (val_max - val_min) / 7.0
v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)
q_vals = tl.minimum(
tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 7
)
grp_offs = tl.arange(0, BLOCK_GRP)
grp_mask = grp_offs < (D // 8)
q_grp = tl.reshape(q_vals, [BLOCK_GRP, 8])
shifts_3bit = tl.arange(0, 8) * 3
packed_24 = tl.sum(q_grp << shifts_3bit[None, :], axis=1)
b0 = (packed_24 & 0xFF).to(tl.uint8)
b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
tl.store(
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3,
b0,
mask=grp_mask,
)
tl.store(
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 1,
b1,
mask=grp_mask,
)
tl.store(
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 2,
b2,
mask=grp_mask,
)
sc_offset = val_cache_offset + VAL_DATA_BYTES
sc_f16 = v_scale.to(tl.float16)
sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + sc_offset + 1,
((sc_u16 >> 8) & 0xFF).to(tl.uint8),
)
zr_f16 = val_min.to(tl.float16)
zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + sc_offset + 3,
((zr_u16 >> 8) & 0xFF).to(tl.uint8),
)
else: # VQB == 4
val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
tl.float32
)
val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
v_scale = (val_max - val_min) / 15.0
v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)
# Quantize all D elements from register (no re-load)
q_all = tl.minimum(
tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 15
)
# Reshape to pairs and pack two 4-bit values per byte
q_pairs = tl.reshape(q_all, [BLOCK_D // 2, 2])
shifts_4 = tl.arange(0, 2) * 4
packed_val = tl.sum((q_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
val_offs = tl.arange(0, BLOCK_D // 2)
val_mask = val_offs < VAL_DATA_BYTES
tl.store(
KV_cache_ptr + slot_base + val_cache_offset + val_offs,
packed_val,
mask=val_mask,
)
sc_offset = val_cache_offset + VAL_DATA_BYTES
sc_f16 = v_scale.to(tl.float16)
sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + sc_offset + 1,
((sc_u16 >> 8) & 0xFF).to(tl.uint8),
)
zr_f16 = val_min.to(tl.float16)
zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + sc_offset + 3,
((zr_u16 >> 8) & 0xFF).to(tl.uint8),
)
# ═══════════════════════════════════════════════════════════════════════
# FP8 key store + value uniform quantization
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_fp8(
Key_ptr, # [NH, D] float16/bfloat16 — raw keys
Value_ptr, # [NH, D] float16/bfloat16 — raw values
KV_cache_ptr, # [total_bytes] uint8 (flattened view)
Slot_mapping_ptr, # [N] int32 — per-token slot indices
# Cache strides (for computing byte offsets)
stride_cache_block: tl.constexpr,
stride_cache_pos: tl.constexpr,
stride_cache_head: tl.constexpr,
# Dimensions
D: tl.constexpr,
H: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_D: tl.constexpr,
# TQ layout
KPS: tl.constexpr,
# Value quantization
VQB: tl.constexpr,
VAL_DATA_BYTES: tl.constexpr,
# Packing block sizes
BLOCK_VAL: tl.constexpr,
BLOCK_GRP: tl.constexpr = 16,
FP8_E4B15: tl.constexpr = 0, # 1 = e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
"""FP8 key cast+scatter + value uniform quantization."""
pid = tl.program_id(0)
token_idx = pid // H
head_idx = pid % H
slot = tl.load(Slot_mapping_ptr + token_idx)
if slot < 0:
return
blk = slot // BLOCK_SIZE
off = slot % BLOCK_SIZE
slot_base = (
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
)
base = pid * D
# ── FP8 KEY: cast to FP8 in-kernel and store ─────────────────
d_offs = tl.arange(0, BLOCK_D)
d_mask = d_offs < D
k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0)
k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv)
k_bytes = k_fp8.to(tl.uint8, bitcast=True)
tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask)
# ── VALUE QUANTIZE + PACK ───────────────────────────────────────
_store_quantized_value(
Value_ptr,
KV_cache_ptr,
base,
slot_base,
d_offs,
d_mask,
D=D,
KPS=KPS,
VQB=VQB,
VAL_DATA_BYTES=VAL_DATA_BYTES,
BLOCK_D=BLOCK_D,
BLOCK_VAL=BLOCK_VAL,
BLOCK_GRP=BLOCK_GRP,
)
# ═══════════════════════════════════════════════════════════════════════
# Fused MSE store: bucketize + MSE index pack + norm store + value pack
# (eliminates 4 PyTorch kernel launches per layer vs pack-only kernel)
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_mse(
# Post-rotation inputs
Y_ptr, # [NH, D] float32 — rotated normalized keys (x_hat @ PiT)
Norms_ptr, # [NH] float32 — key vector norms (||k||)
Value_ptr, # [NH, D] float32 — raw values
# Quantization tables
Midpoints_ptr, # [n_centroids-1] float32
# Cache and indexing
KV_cache_ptr, # [total_bytes] uint8 (flattened view)
Slot_mapping_ptr, # [N] int32 — per-token slot indices
# Cache strides
stride_cache_block: tl.constexpr,
stride_cache_pos: tl.constexpr,
stride_cache_head: tl.constexpr,
# Dimensions
D: tl.constexpr,
H: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_D: tl.constexpr,
# TQ layout
MSE_BYTES: tl.constexpr,
KPS: tl.constexpr,
# Value quantization
VQB: tl.constexpr,
VAL_DATA_BYTES: tl.constexpr,
# Packing block sizes
BLOCK_VAL: tl.constexpr,
# MSE params
MSE_BITS: tl.constexpr,
N_CENTROIDS: tl.constexpr,
BLOCK_GRP: tl.constexpr = 16,
):
"""Fused MSE quantize + pack + store.
Performs binary-search bucketize, MSE index packing, norm storage,
and value quantization in one kernel.
"""
pid = tl.program_id(0)
token_idx = pid // H
head_idx = pid % H
slot = tl.load(Slot_mapping_ptr + token_idx)
if slot < 0:
return
blk = slot // BLOCK_SIZE
off = slot % BLOCK_SIZE
slot_base = (
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
)
base = pid * D
d_offs = tl.arange(0, BLOCK_D)
d_mask = d_offs < D
# ── 1. BINARY SEARCH BUCKETIZE ───────────────────────────────────
# Midpoints are sorted (N_CENTROIDS-1 values); binary search finds
# insertion point in MSE_BITS iterations vs N_CENTROIDS-1 for linear.
y_vec = tl.load(Y_ptr + base + d_offs, mask=d_mask, other=0.0)
lo = tl.zeros([BLOCK_D], dtype=tl.int32)
hi = tl.full([BLOCK_D], N_CENTROIDS - 1, dtype=tl.int32)
for _ in range(MSE_BITS):
mid = (lo + hi) >> 1
# Clamp to valid midpoint index [0, N_CENTROIDS-2] for load safety;
# the search result (lo) is still correct since converged lanes
# don't change.
safe_mid = tl.minimum(mid, N_CENTROIDS - 2)
mid_val = tl.load(Midpoints_ptr + safe_mid, mask=d_mask, other=0.0)
lo = tl.where(y_vec >= mid_val, mid + 1, lo)
hi = tl.where(y_vec >= mid_val, hi, mid)
idx = tl.minimum(lo, N_CENTROIDS - 1)
# ── 2. PACK MSE INDICES from register idx ─────────────────────────
if MSE_BITS == 4:
idx_pairs = tl.reshape(idx, [BLOCK_D // 2, 2])
shifts_4 = tl.arange(0, 2) * 4
packed = tl.sum((idx_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
mse_offs = tl.arange(0, BLOCK_D // 2)
mse_mask = mse_offs < MSE_BYTES
tl.store(KV_cache_ptr + slot_base + mse_offs, packed, mask=mse_mask)
elif MSE_BITS == 3:
grp_offs = tl.arange(0, BLOCK_GRP)
grp_mask = grp_offs < (D // 8)
idx_grp = tl.reshape(idx, [BLOCK_GRP, 8])
shifts_3 = tl.arange(0, 8) * 3
packed_24 = tl.sum((idx_grp & 0x7) << shifts_3[None, :], axis=1)
b0 = (packed_24 & 0xFF).to(tl.uint8)
b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
tl.store(KV_cache_ptr + slot_base + grp_offs * 3, b0, mask=grp_mask)
tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 1, b1, mask=grp_mask)
tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 2, b2, mask=grp_mask)
# ── 3. STORE vec_norm (fp16, 2 bytes) ─────────────────────────────
norm_offset = MSE_BYTES
vn_f16 = tl.load(Norms_ptr + pid).to(tl.float16)
vn_u16 = vn_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + norm_offset, (vn_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + norm_offset + 1, ((vn_u16 >> 8) & 0xFF).to(tl.uint8)
)
# ── 4. VALUE QUANTIZE + PACK ──────────────────────────────────────
_store_quantized_value(
Value_ptr,
KV_cache_ptr,
base,
slot_base,
d_offs,
d_mask,
D=D,
KPS=KPS,
VQB=VQB,
VAL_DATA_BYTES=VAL_DATA_BYTES,
BLOCK_D=BLOCK_D,
BLOCK_VAL=BLOCK_VAL,
BLOCK_GRP=BLOCK_GRP,
)
# ═══════════════════════════════════════════════════════════════════════
# Launcher
# ═══════════════════════════════════════════════════════════════════════
def triton_turboquant_store(
key: torch.Tensor, # [N, H, D] — raw keys (post-RoPE)
value: torch.Tensor, # [N, H, D] — raw values
kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8
slot_mapping: torch.Tensor, # [N] int32
PiT: torch.Tensor, # [D, D] float32
midpoints: torch.Tensor, # [n_centroids-1] float32
mse_bits: int,
key_packed_size: int,
value_quant_bits: int,
key_fp8: bool = False,
):
"""Launch TQ store kernel (FP8 or MSE path)."""
N, H, D = key.shape
NH = N * H
block_size = kv_cache.shape[1]
BLOCK_D = triton.next_power_of_2(D)
mse_bytes = math.ceil(D * mse_bits / 8)
n_centroids = 2**mse_bits
val_data_bytes = math.ceil(D * value_quant_bits / 8)
BLOCK_VAL = triton.next_power_of_2(val_data_bytes)
# Cache strides (element_size=1 for uint8, so stride in bytes = stride())
stride_block = kv_cache.stride(0)
stride_pos = kv_cache.stride(1)
stride_head = kv_cache.stride(2)
block_grp = triton.next_power_of_2(D // 8) if D >= 8 else 1
# ── FP8 PATH: in-kernel FP8 cast + scatter via fp8 kernel ──
if key_fp8:
k_flat = key.reshape(NH, D).contiguous()
v_flat = value.reshape(NH, D).contiguous()
fp8_e4b15 = _use_fp8_e4b15(key.device.index or 0)
grid = (NH,)
_tq_fused_store_fp8[grid](
k_flat,
v_flat,
kv_cache.view(-1),
slot_mapping,
stride_cache_block=stride_block,
stride_cache_pos=stride_pos,
stride_cache_head=stride_head,
D=D,
H=H,
BLOCK_SIZE=block_size,
BLOCK_D=BLOCK_D,
KPS=key_packed_size,
VQB=value_quant_bits,
VAL_DATA_BYTES=val_data_bytes,
BLOCK_VAL=BLOCK_VAL,
BLOCK_GRP=block_grp,
FP8_E4B15=fp8_e4b15,
num_warps=4,
num_stages=1,
)
return
# ── MSE PATH: external GEMM + fused bucketize/pack kernel ──
# Normalize + rotation GEMM externally (cuBLAS is faster than in-kernel)
k_flat = key.float().reshape(NH, D)
norms = k_flat.norm(dim=1, keepdim=True)
x_hat = k_flat / (norms + 1e-8)
y = x_hat @ PiT
v_flat = value.float().reshape(NH, D)
# Fused kernel: bucketize + MSE index pack + norm store + value pack
grid = (NH,)
_tq_fused_store_mse[grid](
y,
norms.squeeze(1),
v_flat,
midpoints,
kv_cache.view(-1),
slot_mapping,
stride_cache_block=stride_block,
stride_cache_pos=stride_pos,
stride_cache_head=stride_head,
D=D,
H=H,
BLOCK_SIZE=block_size,
BLOCK_D=BLOCK_D,
MSE_BYTES=mse_bytes,
KPS=key_packed_size,
VQB=value_quant_bits,
VAL_DATA_BYTES=val_data_bytes,
BLOCK_VAL=BLOCK_VAL,
MSE_BITS=mse_bits,
N_CENTROIDS=n_centroids,
BLOCK_GRP=block_grp,
num_warps=4,
num_stages=1,
)
......@@ -21,6 +21,7 @@ from vllm.v1.kv_cache_interface import (
MLAAttentionSpec,
SinkFullAttentionSpec,
SlidingWindowSpec,
TQFullAttentionSpec,
)
from vllm.v1.request import Request
......@@ -209,7 +210,7 @@ class SingleTypeKVCacheManager(ABC):
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
)
req_blocks.extend(allocated_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
if type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec):
self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def allocate_new_blocks(
......@@ -237,7 +238,7 @@ class SingleTypeKVCacheManager(ABC):
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
if type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec):
self.new_block_ids.extend(b.block_id for b in new_blocks)
return new_blocks
......@@ -1114,6 +1115,7 @@ class SinkFullAttentionManager(FullAttentionManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
TQFullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
......
......@@ -245,6 +245,32 @@ class FullAttentionSpec(AttentionSpec):
)
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
"""FullAttentionSpec with TQ-aware page size.
Python equivalent of the C++ TQ4FullAttentionSpec. Overrides
real_page_size_bytes to use TQ slot bytes instead of the raw
head_size * dtype formula.
"""
tq_slot_size: int = 0
@property
def real_page_size_bytes(self) -> int:
if self.tq_slot_size > 0:
return self.block_size * self.num_kv_heads * self.tq_slot_size
return super().real_page_size_bytes
@classmethod
def merge(cls, specs: list[Self]) -> Self:
merged = super().merge(specs)
assert all(s.tq_slot_size == specs[0].tq_slot_size for s in specs), (
"All TQ layers in the same KV cache group must use the same tq_slot_size."
)
return replace(merged, tq_slot_size=specs[0].tq_slot_size)
@dataclass(frozen=True, kw_only=True)
class MLAAttentionSpec(FullAttentionSpec):
# TODO(Lucas/Chen): less hacky way to do this
......
......@@ -120,7 +120,7 @@ class KVBlockZeroer:
for group in attn_groups_iter:
spec = group.kv_cache_spec
if type(spec) is not FullAttentionSpec:
if not isinstance(spec, FullAttentionSpec):
continue
if group.kv_cache_group_id >= len(kernel_block_sizes):
continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment