# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant attention backend for vLLM.

Prefill: Standard scaled dot-product attention on uncompressed K/V,
         then quantize K and store K+V into combined cache slot.
Decode:  Compute TQ attention scores from compressed cache,
         unpack FP16 values, softmax + weighted sum.

Cache layout (no leading 2 dimension):
    (num_blocks, block_size, num_kv_heads, slot_size)
    where slot_size = key_packed_size + value_fp16_size

Per-head per-position slot layout:
    [key_packed (kps bytes) | value_fp16 (D*2 bytes)]

For turboquant_k3v4_nc head_dim=256: [100 bytes key | 512 bytes value] = 612
"""

import functools
import math
from dataclasses import dataclass
from typing import Any, ClassVar

import torch
import torch.nn.functional as F

from vllm.config import get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.triton_utils import triton
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionImpl,
    AttentionLayer,
    AttentionMetadata,
    AttentionMetadataBuilder,
    AttentionType,
    CommonAttentionMetadata,
    MultipleOf,
)
from vllm.v1.attention.backends.fa_utils import (
    is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.ops.triton_turboquant_decode import (
    _tq_full_dequant_kv,
    _use_fp8_e4b15,
    triton_turboquant_decode_attention,
)
from vllm.v1.attention.ops.triton_turboquant_store import triton_turboquant_store

_HAS_FLASH_ATTN = is_flash_attn_varlen_func_available()
if _HAS_FLASH_ATTN:
    from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

# Continuation prefill: for small continuation chunks (q_len ≤ threshold),
# use the TQ decode kernel directly instead of full-dequant + flash_attn.
# do_kv_cache_update already stored all tokens to TQ cache, so the decode
# kernel can read them efficiently. This avoids O(cached_len) dequant work
# per continuation, eliminating the O(N²/chunk_size) collapse at long context.
_CONTINUATION_DECODE_THRESHOLD = 128


def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
    """Orthonormal Hadamard matrix (Sylvester construction), cached per (d, device).

    Precomputed D×D matrix enables matmul-based WHT — single cuBLAS GEMM
    instead of log2(D) butterfly kernel launches. 64KB for D=128.

    Args:
        d: Matrix side length; must be a positive power of two.
        device_str: Anything ``torch.device`` accepts as a string.

    Raises:
        ValueError: If ``d`` is not a positive power of two.
    """
    # Round-trip through torch.device so equivalent spellings share one cache
    # entry. An index-less "cuda" is NOT expanded to "cuda:<n>" by this; the
    # callers in this file pass str(tensor.device).
    return _build_hadamard_cached(d, str(torch.device(device_str)))


@functools.cache
def _build_hadamard_cached(d: int, device_str: str) -> torch.Tensor:
    """Build (once per ``(d, device_str)``) the d×d Hadamard matrix / sqrt(d).

    Raises:
        ValueError: If ``d`` is not a positive power of two. The doubling
            loop below can only produce power-of-two sizes, so any other
            ``d`` would yield a larger, incorrectly scaled matrix.
    """
    if d < 1 or (d & (d - 1)) != 0:
        raise ValueError(
            f"Hadamard (Sylvester) construction requires a positive "
            f"power-of-two size, got d={d}"
        )
    H = torch.tensor([[1.0]])
    while H.shape[0] < d:
        # Sylvester step: H_2n = [[H, H], [H, -H]]
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    return (H / math.sqrt(d)).to(torch.device(device_str))


class TurboQuantAttentionBackend(AttentionBackend):
    """Attention backend using TurboQuant KV-cache compression."""

    accept_output_buffer: bool = True
    forward_includes_kv_cache_update: bool = False

    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
    ]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "turboquant_k8v4",
        "turboquant_4bit_nc",
        "turboquant_k3v4_nc",
        "turboquant_3bit_nc",
    ]

    @staticmethod
    def get_name() -> str:
        """Return the backend's registry name."""
        return "TURBOQUANT"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the KV-cache block sizes this backend accepts."""
        return [16, 32, 64, 128]

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Only decoder attention is supported."""
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        """Per-head quantization scales are not supported."""
        return False

    @staticmethod
    def get_impl_cls() -> type["TurboQuantAttentionImpl"]:
        """Return the attention implementation class."""
        return TurboQuantAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["TurboQuantMetadataBuilder"]:
        """Return the metadata builder class."""
        return TurboQuantMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "turboquant_4bit_nc",
    ) -> tuple[int, ...]:
        """Combined K+V cache shape — no leading 2 dimension.

        Standard attention backends use
        (2, num_blocks, block_size, num_kv_heads, head_dim) with a leading 2
        to separate K and V. TurboQuant packs K+V into a single interleaved
        slot per head per position, so the cache is:
            (num_blocks, block_size, num_kv_heads, slot_size_aligned)

        Each slot = [key_packed | value_packed | padding].

        This is safe because TQ has its own get_kv_cache_shape override and
        never shares cache tensors with other backends. Layers that fall back
        to native dtype via kv_cache_dtype_skip_layers get their own
        standard-shaped cache allocation.

        head_size is the model's real head_dim. slot_size_aligned is
        computed from the TQ config to ensure correct cache allocation
        for all head dims.
        """
        from vllm.model_executor.layers.quantization.turboquant.config import (
            TurboQuantConfig,
        )

        tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype_str, head_size)
        return (num_blocks, block_size, num_kv_heads, tq_config.slot_size_aligned)

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
        """Accept any cache dtype whose name starts with ``turboquant_``."""
        if kv_cache_dtype is None:
            return False
        return kv_cache_dtype.startswith("turboquant_")

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Accept any positive head size."""
        # head_size from spec is effective_head_size (padded_slot//2),
        # not the model's actual head_dim. Accept any positive value.
        return head_size > 0


@dataclass
class TurboQuantMetadata(AttentionMetadata):
    """Metadata for TurboQuant attention."""

    seq_lens: torch.Tensor  # (num_reqs,) — total context length per request
    slot_mapping: torch.Tensor  # (num_tokens,) — cache slot for each token
    block_table: torch.Tensor  # (num_reqs, max_num_blocks)
    query_start_loc: torch.Tensor  # (num_reqs + 1,) — cu_seqlens for queries
    num_actual_tokens: int = 0  # actual tokens (excluding padding)
    max_query_len: int = 0  # longest query in batch
    max_seq_len: int = 0  # longest context in batch
    is_prefill: bool = False
    num_decodes: int = 0  # number of decode requests (first in batch)
    num_decode_tokens: int = 0  # tokens from decode requests
    # Memo for the prefill fast-path check: True when every request has
    # q_len == seq_len, False otherwise, None until first computed by
    # TurboQuantAttentionImpl._prefill_attention.
    all_first_chunk: bool | None = None


class TurboQuantMetadataBuilder(AttentionMetadataBuilder[TurboQuantMetadata]):
    """Builds TurboQuantMetadata from scheduler output."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        """Initialise the base builder and set the reorder threshold to 1."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=False)

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> TurboQuantMetadata:
        """Build metadata for CUDA graph capture.

        Note that ``seq_lens`` is filled in place, so the tensor referenced by
        ``common_attn_metadata`` is modified as well.
        """
        attn_metadata = self.build(0, common_attn_metadata)
        # Set seq_lens to 1 so CUDA graph capture is fast
        # (real seq_lens are filled at replay time).
        attn_metadata.seq_lens.fill_(1)
        return attn_metadata

    def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
        """Build TurboQuantMetadata from common attention metadata.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not used here.
        """
        cam = common_attn_metadata

        # With reorder_batch_threshold=1, the model runner guarantees
        # decodes come first in the batch. split_decodes_and_prefills
        # finds the boundary (operates on CPU tensors — no GPU sync).
        assert self.reorder_batch_threshold is not None
        num_decodes, _, num_decode_tokens, _ = split_decodes_and_prefills(
            cam, decode_threshold=self.reorder_batch_threshold
        )

        return TurboQuantMetadata(
            seq_lens=cam.seq_lens,
            slot_mapping=cam.slot_mapping,
            block_table=cam.block_table_tensor,
            query_start_loc=cam.query_start_loc,
            num_actual_tokens=cam.num_actual_tokens,
            max_query_len=cam.max_query_len,
            max_seq_len=cam.max_seq_len,
            is_prefill=(cam.max_query_len > 1),
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
        )


class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
    """TurboQuant attention implementation.

    K/V are quantized and stored by ``triton_turboquant_store``; decode (and
    small continuation chunks) use ``triton_turboquant_decode_attention``.
    Prefill runs flash-attn (when available) or PyTorch SDPA on the raw
    K/V of the current chunk, dequantizing cached K/V for large
    continuation chunks.
    """

    supports_quant_query_input: bool = False

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        **kwargs,
    ):
        """Store head geometry and derive TurboQuant kernel constants.

        ``alibi_slopes``, ``sliding_window``, ``logits_soft_cap``,
        ``attn_type``, ``kv_sharing_target_layer_name`` and ``**kwargs`` are
        accepted but not read by this implementation.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        from vllm.model_executor.layers.quantization.turboquant.config import (
            TurboQuantConfig,
        )

        self.tq_config = TurboQuantConfig.from_cache_dtype(kv_cache_dtype, head_size)

        # Pre-compute kernel constants from config (avoid repeated arithmetic)
        cfg = self.tq_config
        self._mse_bytes = (
            math.ceil(head_size * cfg.key_mse_bits / 8)
            if not cfg.key_fp8
            else head_size
        )
        self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
        self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1

        # Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
        # and benchmarks show no regression vs dynamic in eager mode).
        vllm_config = get_current_vllm_config()
        self.max_num_kv_splits = (
            vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
        )

    def _ensure_on_device(self, layer, device):
        """One-time derivation of TQ buffers (rotation matrices, midpoints).

        Registered buffers (_tq_signs, _tq_centroids) are already on the
        correct device via register_buffer + model.to(device).

        Sets ``layer._tq_PiT``, ``layer._tq_Pi``, ``layer._tq_midpoints`` and
        the ``layer._tq_cached`` marker; later calls are no-ops.
        """
        if not hasattr(layer, "_tq_cached"):
            D = layer._tq_signs.shape[0]
            signs = layer._tq_signs.to(device=device, dtype=torch.float32)

            # WHT rotation: orthonormal + self-inverse, enabling future
            # in-kernel butterfly fusion and trivial inverse for continuation.
            H = _build_hadamard(D, str(device))
            layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous()
            layer._tq_Pi = layer._tq_PiT.T.contiguous()

            c = layer._tq_centroids.to(device=device, dtype=torch.float32)
            # Precompute midpoints for threshold-based quantization
            c_sorted, _ = c.sort()
            layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2

            # Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf)
            # are pre-allocated via register_buffer in Attention.__init__
            # and moved to GPU by model.to(device) — no allocation needed
            # here. The memory profiler sees them before KV cache sizing.
            layer._tq_cached = True

    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Store compressed K/V into the combined TQ cache.

        Called as a separate custom op (unified_kv_cache_update) BEFORE
        the attention forward, matching FlashAttention's split pattern.

        slot_mapping is already sliced to num_actual_tokens by the caller.
        """
        N = slot_mapping.shape[0]
        if N <= 0:
            return

        device = key.device
        self._ensure_on_device(layer, device)

        k = key[:N].view(N, self.num_kv_heads, self.head_size)
        v = value[:N].view(N, self.num_kv_heads, self.head_size)
        self._store_kv(k, v, kv_cache, slot_mapping, layer)

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: "TurboQuantMetadata",
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute attention for a pure-decode, pure-prefill or mixed batch.

        The KV cache is not written here; see ``do_kv_cache_update``.
        ``output_scale`` and ``output_block_scale`` are accepted but unused.

        Returns:
            ``output`` (allocated here when not supplied) with the first
            ``num_actual_tokens`` rows filled in. It is zero-filled and
            returned early when ``attn_metadata`` is None or there are no
            actual tokens.
        """
        num_tokens = query.shape[0]

        if output is None:
            output = torch.zeros(
                num_tokens,
                self.num_heads * self.head_size,
                dtype=query.dtype,
                device=query.device,
            )

        if attn_metadata is None:
            return output.fill_(0)

        # Slice to actual tokens
        N = attn_metadata.num_actual_tokens
        if N <= 0:
            return output.fill_(0)

        q = query[:N].view(N, self.num_heads, self.head_size)

        # Get TQ buffers, ensure on device (one-time migration).
        # Use Any-typed alias for dynamic _tq_* attrs set by _ensure_on_device.
        tq_layer: Any = layer
        device = q.device
        self._ensure_on_device(tq_layer, device)

        Pi = tq_layer._tq_Pi
        PiT = tq_layer._tq_PiT
        centroids = tq_layer._tq_centroids

        # Compute attention (KV cache was already updated by do_kv_cache_update)
        # With reorder_batch_threshold=1, decodes come first in the batch.
        # num_decodes/num_decode_tokens from metadata give the split point.
        num_decodes = attn_metadata.num_decodes
        num_decode_tokens = attn_metadata.num_decode_tokens

        if not attn_metadata.is_prefill:
            # Pure decode batch — fast path
            attn_out = self._decode_attention(
                q, kv_cache, attn_metadata, Pi, centroids, PiT, layer
            )
        elif num_decodes == 0:
            # Pure prefill batch
            k = key[:N].view(N, self.num_kv_heads, self.head_size)
            v = value[:N].view(N, self.num_kv_heads, self.head_size)
            attn_out = self._prefill_attention(
                q,
                k,
                v,
                kv_cache,
                attn_metadata,
                Pi,
                centroids,
                PiT,
                layer=layer,
            )
        else:
            # Mixed batch: decodes first (guaranteed by reorder_batch).
            attn_out = torch.zeros(
                N, self.num_heads, self.head_size, device=device, dtype=q.dtype
            )

            # --- Decode portion (first num_decodes requests) ---
            # Use full-batch max_seq_len as safe upper bound (no GPU sync).
            decode_meta = TurboQuantMetadata(
                seq_lens=attn_metadata.seq_lens[:num_decodes],
                slot_mapping=attn_metadata.slot_mapping[:num_decode_tokens],
                block_table=attn_metadata.block_table[:num_decodes],
                query_start_loc=attn_metadata.query_start_loc[: num_decodes + 1],
                num_actual_tokens=num_decode_tokens,
                max_query_len=1,
                max_seq_len=attn_metadata.max_seq_len,
                is_prefill=False,
            )
            attn_out[:num_decode_tokens] = self._decode_attention(
                q[:num_decode_tokens], kv_cache, decode_meta, Pi, centroids, PiT, layer
            )

            # --- Prefill portion (remaining requests) ---
            # CRITICAL: use prefill-specific max_seq_len so flash_attn's
            # fast path (max_query_len == max_seq_len) can trigger for
            # first-chunk prefills. Using full-batch max_seq_len breaks
            # this because decode requests inflate max_seq_len.
            prefill_seq_lens = attn_metadata.seq_lens[num_decodes:]
            # Max is taken over a Python list.
            # NOTE(review): .tolist() copies to host, which synchronizes if
            # seq_lens lives on the GPU — TODO confirm whether a CPU-side
            # copy of seq_lens could be threaded through the metadata.
            prefill_max_seq = max(prefill_seq_lens.tolist())
            prefill_qsl = (
                attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens
            )
            prefill_meta = TurboQuantMetadata(
                seq_lens=prefill_seq_lens,
                slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N],
                block_table=attn_metadata.block_table[num_decodes:],
                query_start_loc=prefill_qsl,
                num_actual_tokens=N - num_decode_tokens,
                max_query_len=attn_metadata.max_query_len,
                max_seq_len=prefill_max_seq,
                is_prefill=True,
            )
            k = key[:N].view(N, self.num_kv_heads, self.head_size)
            v = value[:N].view(N, self.num_kv_heads, self.head_size)
            attn_out[num_decode_tokens:] = self._prefill_attention(
                q[num_decode_tokens:],
                k[num_decode_tokens:],
                v[num_decode_tokens:],
                kv_cache,
                prefill_meta,
                Pi,
                centroids,
                PiT,
                layer=layer,
            )

        # Write into output buffer: attn_out is (N, Hq, D)
        # output may be 2D (N, Hq*D) or 3D (N, Hq, D)
        if output.ndim == 3:
            output[:N] = attn_out.to(output.dtype)
        else:
            output[:N] = attn_out.reshape(N, -1).to(output.dtype)
        return output

    # ------------------------------------------------------------------ #
    #  Store K/V into combined cache (vectorized)                          #
    # ------------------------------------------------------------------ #
    def _store_kv(
        self,
        key: torch.Tensor,  # (N, Hk, D)
        value: torch.Tensor,  # (N, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        slot_mapping: torch.Tensor,
        layer: Any,
    ):
        """Quantize + store via fused Triton kernel."""
        triton_turboquant_store(
            key,
            value,
            kv_cache,
            slot_mapping,
            layer._tq_PiT,
            layer._tq_midpoints,
            mse_bits=self.tq_config.key_mse_bits,
            key_packed_size=self.tq_config.key_packed_size,
            value_quant_bits=self.tq_config.effective_value_quant_bits,
            key_fp8=self.tq_config.key_fp8,
        )

    # ------------------------------------------------------------------ #
    #  Prefill: SDPA on raw Q/K/V with causal mask                         #
    # ------------------------------------------------------------------ #
    def _prefill_attention(
        self,
        query: torch.Tensor,  # (N, Hq, D)
        key: torch.Tensor,  # (N, Hk, D)
        value: torch.Tensor,  # (N, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        attn_metadata: TurboQuantMetadata,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
        PiT: torch.Tensor | None = None,
        layer: Any = None,
    ) -> torch.Tensor:
        """Causal attention for the prefill requests of a batch.

        Three strategies, chosen per batch / per request:

        * whole-batch flash-attn over the raw chunk K/V when every request is
          a first chunk (q_len == seq_len);
        * per-request flash-attn / SDPA for first-chunk requests otherwise;
        * for continuation chunks (q_len < seq_len), either the TQ decode
          kernel (q_len <= _CONTINUATION_DECODE_THRESHOLD) or
          ``_continuation_prefill``.

        May set ``attn_metadata.all_first_chunk`` as a memo.

        Returns:
            Tensor of shape (N, Hq, D) in ``query.dtype``.
        """
        N, Hq, D = query.shape
        query_start_loc = attn_metadata.query_start_loc
        num_reqs = query_start_loc.shape[0] - 1

        # Host-side copies of the per-request lengths; filled lazily so the
        # fast path only pays for them once per metadata object.
        qsl: list[int] | None = None
        seq_lens_list: list[int] | None = None

        # Fast path: use flash_attn for first-chunk prefills (all K/V in batch).
        # max_query_len == max_seq_len (both Python ints) is only a necessary
        # condition: a first-chunk request of length L batched with a
        # continuation whose total length is <= L has equal maxima, yet the
        # continuation must still attend to its cached K/V. Confirm per
        # request that q_len == seq_len, and memoise the answer on the
        # metadata so the host copy happens once per metadata object.
        use_fast_path = (
            _HAS_FLASH_ATTN
            and attn_metadata.max_query_len == attn_metadata.max_seq_len
        )
        if use_fast_path:
            if attn_metadata.all_first_chunk is None:
                qsl = query_start_loc.tolist()
                seq_lens_list = attn_metadata.seq_lens.tolist()
                attn_metadata.all_first_chunk = all(
                    qsl[i + 1] - qsl[i] == seq_lens_list[i] for i in range(num_reqs)
                )
            use_fast_path = attn_metadata.all_first_chunk

        if use_fast_path:
            output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype)
            flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=query_start_loc,
                cu_seqlens_k=query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                max_seqlen_k=attn_metadata.max_query_len,
                softmax_scale=self.scale,
                causal=True,
                out=output,
            )
            return output

        # Continuation or no flash_attn: per-request attention.
        # For continuation chunks (seq_len > q_len), we must attend to
        # previously cached K/V from the TQ cache, not just the current
        # chunk's raw K/V.
        Hk = key.shape[1]
        use_gqa = Hk < Hq

        output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype)

        # Convert to Python lists once instead of per-request .item() calls
        # that each force a sync (reuse the lists if already fetched above).
        if qsl is None or seq_lens_list is None:
            qsl = query_start_loc.tolist()
            seq_lens_list = attn_metadata.seq_lens.tolist()

        # Pre-allocate cu_seqlens for single-request flash_attn calls
        # to avoid per-request host→device tensor creation.
        _cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32)

        for i in range(num_reqs):
            q_start = qsl[i]
            q_end = qsl[i + 1]
            q_len = q_end - q_start
            if q_len <= 0:
                continue

            seq_len = seq_lens_list[i]
            q_seq = query[q_start:q_end]  # (q_len, Hq, D)
            k_seq = key[q_start:q_end]  # (q_len, Hk, D)
            v_seq = value[q_start:q_end]  # (q_len, Hk, D)

            if q_len == seq_len:
                # First-chunk prefill: all K/V are in the current batch.
                if _HAS_FLASH_ATTN:
                    out = torch.empty_like(q_seq)
                    _cu_2[1] = q_len
                    cu = _cu_2
                    flash_attn_varlen_func(
                        q=q_seq,
                        k=k_seq,
                        v=v_seq,
                        cu_seqlens_q=cu,
                        cu_seqlens_k=cu,
                        max_seqlen_q=q_len,
                        max_seqlen_k=q_len,
                        softmax_scale=self.scale,
                        causal=True,
                        out=out,
                    )
                else:
                    # Use the (1, H, L, D) layout, same as the SDPA fallback
                    # in _continuation_prefill, so the head dim sits where
                    # enable_gqa is documented to look for it.
                    q_t = q_seq.transpose(0, 1).unsqueeze(0)  # (1, Hq, q_len, D)
                    k_t = k_seq.transpose(0, 1).unsqueeze(0)  # (1, Hk, q_len, D)
                    v_t = v_seq.transpose(0, 1).unsqueeze(0)  # (1, Hk, q_len, D)
                    out = F.scaled_dot_product_attention(
                        q_t,
                        k_t,
                        v_t,
                        is_causal=True,
                        scale=self.scale,
                        enable_gqa=use_gqa,
                    )[0].transpose(0, 1)  # (q_len, Hq, D)
                output[q_start:q_end] = out.to(query.dtype)
            else:
                # Continuation chunk: tokens already stored to TQ cache
                # by do_kv_cache_update. Use decode kernel directly to
                # avoid O(cached_len) full-dequant per continuation.
                # For large continuations, fall back to _continuation_prefill.
                cached_len = seq_len - q_len
                if q_len <= _CONTINUATION_DECODE_THRESHOLD:
                    # Fast path: treat each query as a decode request
                    # with incremental seq_lens for causal masking.
                    synth_seq_lens = torch.arange(
                        cached_len + 1,
                        seq_len + 1,
                        device=query.device,
                        dtype=attn_metadata.seq_lens.dtype,
                    )
                    synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1)
                    out = triton_turboquant_decode_attention(
                        query=q_seq,
                        kv_cache=kv_cache,
                        block_table=synth_bt,
                        seq_lens=synth_seq_lens,
                        Pi=Pi,
                        centroids=centroids,
                        scale=self.scale,
                        mse_bits=self.tq_config.key_mse_bits,
                        key_packed_size=self.tq_config.key_packed_size,
                        value_quant_bits=(self.tq_config.effective_value_quant_bits),
                        key_fp8=self.tq_config.key_fp8,
                        norm_correction=self.tq_config.norm_correction,
                        PiT=PiT,
                    )
                else:
                    # Large continuation: dequant cached K/V and use
                    # flash_attn for better throughput.
                    out = self._continuation_prefill(
                        layer,
                        q_seq,
                        k_seq,
                        v_seq,
                        kv_cache,
                        attn_metadata.block_table[i : i + 1],
                        cached_len,
                        seq_len,
                        Pi,
                        centroids,
                    )
                output[q_start:q_end] = out.to(query.dtype)

        return output

    def _continuation_prefill(
        self,
        layer: Any,
        query: torch.Tensor,  # (q_len, Hq, D)
        key_chunk: torch.Tensor,  # (q_len, Hk, D)
        val_chunk: torch.Tensor,  # (q_len, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        block_table: torch.Tensor,  # (1, max_num_blocks)
        cached_len: int,
        seq_len: int,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
    ) -> torch.Tensor:
        """Handle continuation chunk by dequanting cached K/V from TQ cache.

        Dequants previously cached K/V, concatenates with the current
        chunk's raw K/V, then runs flash_attn with causal masking (or SDPA
        with an explicit mask when flash_attn is unavailable).

        The float16 dequant buffers are kept on ``layer`` as
        ``_tq_k_dequant_buf`` / ``_tq_v_dequant_buf`` and are reallocated
        only when a longer cached prefix is requested.

        Returns:
            Tensor of shape (q_len, Hq, D).
        """
        q_len, Hq, D = query.shape
        Hk = key_chunk.shape[1]
        device = query.device
        block_size = kv_cache.shape[1]
        BLOCK_D = triton.next_power_of_2(D)

        mse_bytes = self._mse_bytes
        val_data_bytes = self._val_data_bytes

        # Dequant cached K/V from TQ cache
        # Allocate slightly over to align to block_size for the grid.
        # Reuse cached buffers to avoid per-call allocation (~16MB at 8K).
        alloc_len = math.ceil(cached_len / block_size) * block_size
        buf_shape = (1, Hk, alloc_len, D)
        k_buf = getattr(layer, "_tq_k_dequant_buf", None)
        if k_buf is None or k_buf.shape[2] < alloc_len:
            k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
            v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
            layer._tq_k_dequant_buf = k_buf
            layer._tq_v_dequant_buf = v_buf
        else:
            v_buf = layer._tq_v_dequant_buf
        # Views over the first alloc_len positions, zeroed in place.
        k_cached = k_buf[:, :, :alloc_len, :].zero_()
        v_cached = v_buf[:, :, :alloc_len, :].zero_()

        grid = (alloc_len, 1 * Hk)
        _tq_full_dequant_kv[grid](
            kv_cache,
            block_table,
            centroids,
            k_cached,
            v_cached,
            k_cached.stride(0),
            k_cached.stride(1),
            k_cached.stride(2),
            v_cached.stride(0),
            v_cached.stride(1),
            v_cached.stride(2),
            kv_cache.stride(0),
            kv_cache.stride(1),
            kv_cache.stride(2),
            block_table.stride(0),
            HEAD_DIM=D,
            BLOCK_SIZE=block_size,
            NUM_KV_HEADS=Hk,
            MSE_BYTES=mse_bytes,
            KPS=self.tq_config.key_packed_size,
            VQB=self.tq_config.effective_value_quant_bits,
            VAL_DATA_BYTES=val_data_bytes,
            MSE_BITS=self.tq_config.key_mse_bits,
            KEY_FP8=1 if self.tq_config.key_fp8 else 0,
            BLOCK_D=BLOCK_D,
            NORM_CORRECTION=1 if self.tq_config.norm_correction else 0,
            FP8_E4B15=_use_fp8_e4b15(device.index or 0),
            num_warps=4,
        )

        # Inverse-rotate MSE keys back to original space
        if not self.tq_config.key_fp8:
            k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float()
            k_flat = k_flat @ Pi
            k_cached_trim = (
                k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1)
            )  # (cached_len, Hk, D)
        else:
            k_cached_trim = (
                k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
            )  # (cached_len, Hk, D)

        v_cached_trim = (
            v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
        )  # (cached_len, Hk, D)

        # Concatenate cached + current chunk K/V (match query dtype)
        qdtype = query.dtype
        k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0)
        v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0)

        # Attention: q_len queries attending to seq_len K/V with causal mask
        if _HAS_FLASH_ATTN:
            output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype)
            cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
            cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
            flash_attn_varlen_func(
                q=query,
                k=k_full,
                v=v_full,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=seq_len,
                softmax_scale=self.scale,
                causal=True,
                out=output,
            )
            return output
        else:
            # SDPA fallback: expand KV for GQA, build causal mask
            q_t = query.transpose(0, 1).unsqueeze(0)  # (1, Hq, q_len, D)
            k_t = k_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
            v_t = v_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
            # Build causal mask: query position p can attend to K position j
            # where j <= cached_len + p (p is 0-indexed within chunk)
            q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
            k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = k_pos <= q_pos  # (q_len, seq_len)
            out = F.scaled_dot_product_attention(
                q_t,
                k_t,
                v_t,
                attn_mask=mask,
                scale=self.scale,
                enable_gqa=(Hk < Hq),
            )  # (1, Hq, q_len, D)
            return out[0].transpose(0, 1)  # (q_len, Hq, D)

    # ------------------------------------------------------------------ #
    #  Decode: Triton TQ decode attention                                  #
    # ------------------------------------------------------------------ #
    def _decode_attention(
        self,
        query: torch.Tensor,  # (B, Hq, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        attn_metadata: TurboQuantMetadata,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
        PiT: torch.Tensor | None = None,
        layer: torch.nn.Module | None = None,
    ) -> torch.Tensor:
        """Run the Triton TQ decode kernel over the compressed cache.

        Forwards the block table and seq_lens from ``attn_metadata`` together
        with the TQ config constants, and returns the kernel's result.
        """
        # Pick up decode scratch buffers from the layer when it has them;
        # any that are missing are passed to the kernel wrapper as None.
        mid_o_buf = output_buf = lse_buf = None
        if layer is not None:
            mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
            output_buf = getattr(layer, "_tq_output_buf", None)
            lse_buf = getattr(layer, "_tq_lse_buf", None)

        result = triton_turboquant_decode_attention(
            query=query,
            kv_cache=kv_cache,
            block_table=attn_metadata.block_table,
            seq_lens=attn_metadata.seq_lens,
            Pi=Pi,
            centroids=centroids,
            scale=self.scale,
            mse_bits=self.tq_config.key_mse_bits,
            key_packed_size=self.tq_config.key_packed_size,
            value_quant_bits=self.tq_config.effective_value_quant_bits,
            key_fp8=self.tq_config.key_fp8,
            norm_correction=self.tq_config.norm_correction,
            PiT=PiT,
            mid_o_buf=mid_o_buf,
            output_buf=output_buf,
            lse_buf=lse_buf,
            buf_holder=layer,
            max_num_kv_splits=self.max_num_kv_splits,
        )
        return result