# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Bridge vLLM paged KV layout to compactor Triton kernels. vLLM FlashAttention KV cache is shaped [num_blocks, block_size, num_kv_heads, head_dim]. Compactor kernels expect a flat buffer [CACHE_SIZE, head_dim] and a page table global_page_table[batch, kv_head, logical_page] -> physical_page_id where each physical page holds ``block_size`` consecutive rows belonging to that KV head only. When num_kv_heads == 1 (MQA), a vLLM block maps 1:1 to compactor rows: row_index = physical_block_id * block_size + offset_in_block. When ``num_kv_heads > 1``, we permute to head-major ``[num_kv_heads, num_blocks, block_size, head_dim]`` and flatten to ``[num_kv_heads * num_blocks * block_size, head_dim]`` so each KV head occupies a disjoint row range in the flat buffer. The page table is built so each logical compression page maps to ``global_row // PAGE_SIZE`` in that layout (see ``build_page_table_head_major``). """ from __future__ import annotations import torch def _cdiv(n: int, d: int) -> int: return (n + d - 1) // d def flatten_kv_cache_head_major( key_cache: torch.Tensor, value_cache: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """View ``[nb, bs, H, D]`` caches as ``[H*nb*bs, D]`` in head-major order.""" if key_cache.shape != value_cache.shape: raise ValueError("key_cache and value_cache must match") nb, bs, hkv, d = key_cache.shape k_hm = key_cache.permute(2, 0, 1, 3).contiguous() v_hm = value_cache.permute(2, 0, 1, 3).contiguous() k_flat = k_hm.reshape(hkv * nb * bs, d) v_flat = v_hm.reshape(hkv * nb * bs, d) return k_flat, v_flat def write_head_major_flat_to_interleaved( k_flat: torch.Tensor, v_flat: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, ) -> None: """Copy ``[H*nb*bs, D]`` head-major flats back to ``[nb, bs, H, D]``.""" nb, bs, hkv, d = key_cache.shape k_hm = k_flat.view(hkv, nb, bs, d) v_hm = v_flat.view(hkv, nb, bs, d) key_cache.copy_(k_hm.permute(1, 2, 0, 3)) value_cache.copy_(v_hm.permute(1, 2, 0, 3)) def build_page_table_head_major( block_table: torch.Tensor, num_kv_heads: int, num_blocks: int, block_size: int, page_size: int, max_batches: int, ) -> torch.Tensor: """Build ``[max_batches, H, max_chain]`` page table for head-major flat KV. Chains physical page ids in ``block_table`` order for each (batch, head). Each entry is ``global_row // page_size`` where ``global_row`` indexes rows in the head-major flat buffer (see ``flatten_kv_cache_head_major``). """ bsz, max_blocks = block_table.shape if bsz > max_batches: raise ValueError("batch size exceeds max_batches for page table") num_pages_per_block = _cdiv(block_size, page_size) max_chain = max_blocks * num_pages_per_block out = torch.zeros( (max_batches, num_kv_heads, max_chain), dtype=torch.int32, device=block_table.device, ) bt = block_table.to(torch.int64) for b in range(bsz): for h in range(num_kv_heads): lp_idx = 0 for blk_i in range(max_blocks): bid = int(bt[b, blk_i].item()) if bid < 0: continue if bid >= num_blocks: raise ValueError( f"block_table[{b},{blk_i}]={bid} out of range " f"num_blocks={num_blocks}" ) base_row = h * (num_blocks * block_size) + bid * block_size for p in range(num_pages_per_block): start_row = base_row + p * page_size if start_row >= base_row + block_size: break phys = start_row // page_size out[b, h, lp_idx] = int(phys) lp_idx += 1 return out def flatten_kv_cache_plane( key_cache: torch.Tensor, value_cache: torch.Tensor, num_kv_heads: int, ) -> tuple[torch.Tensor, torch.Tensor]: """View (num_blocks, block_size, HKV, D) caches as [num_blocks*block_size*HKV, D]. This matches compactor row indexing only when HKV == 1 (see module doc). """ if num_kv_heads != 1: raise ValueError( "flatten_kv_cache_plane requires num_kv_heads==1 for compactor layout" ) if key_cache.shape != value_cache.shape: raise ValueError("key_cache and value_cache must match") # [num_blocks, block_size, 1, D] -> [num_blocks * block_size, D] nb, bs, hkv, d = key_cache.shape if hkv != 1: raise ValueError("expected num_kv_heads==1") k_flat = key_cache.reshape(nb * bs, d) v_flat = value_cache.reshape(nb * bs, d) if not k_flat.is_contiguous(): k_flat = k_flat.contiguous() if not v_flat.is_contiguous(): v_flat = v_flat.contiguous() return k_flat, v_flat def block_table_to_global_page_table( block_table: torch.Tensor, num_kv_heads: int, max_batches: int, ) -> torch.Tensor: """Build [max_batches, HKV, num_logical_pages] int32 page table. For MQA, every KV head reuses the same physical block ids as vLLM's table. """ # block_table: [num_reqs_padded, max_num_blocks] bsz, max_lp = block_table.shape if bsz > max_batches: raise ValueError("batch size exceeds max_batches for page table") out = torch.zeros( (max_batches, num_kv_heads, max_lp), dtype=torch.int32, device=block_table.device, ) bt = block_table.to(torch.int32)[:bsz] if num_kv_heads == 1: out[:bsz, 0, :max_lp] = bt else: for h in range(num_kv_heads): out[:bsz, h, :max_lp] = bt return out def build_batch_mapping(num_reqs: int, device: torch.device) -> torch.Tensor: """Local batch index -> global batch row (identity).""" return torch.arange(num_reqs, dtype=torch.int32, device=device)