from __future__ import annotations from typing import Optional, Tuple import torch from vllm.triton_utils import HAS_TRITON if HAS_TRITON: import triton import triton.language as tl def _require_triton() -> None: if not HAS_TRITON: raise RuntimeError("Triton is not available.") def _check_cuda(*tensors: torch.Tensor) -> None: for t in tensors: if not isinstance(t, torch.Tensor): raise TypeError("Expected torch.Tensor inputs.") if t.device.type != "cuda": raise RuntimeError("Triton KV cache ops require CUDA/ROCm tensors.") @triton.autotune( configs=[ triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2), triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2), triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2), ], key=["D"], ) @triton.jit def _gather_k_to_packed_kernel( K_ptr, out_ptr, blk_ids_ptr, req_blk_starts_ptr, cu_seqlens_ptr, seq_lens_ptr, B, H, max_blocks, block_size, D, sKb, sKh, sKt, sKd, so_t, so_h, so_d, BLOCK_T: tl.constexpr, BLOCK_D: tl.constexpr, ): pid_bh = tl.program_id(0) pid_t = tl.program_id(1) pid_d = tl.program_id(2) b = pid_bh // H h = pid_bh % H if b >= B: return seq_len = tl.load(seq_lens_ptr + b) if seq_len <= 0: return t0 = pid_t * BLOCK_T t_range = t0 + tl.arange(0, BLOCK_T) t_mask = t_range < seq_len d0 = pid_d * BLOCK_D d_range = d0 + tl.arange(0, BLOCK_D) d_mask = d_range < D # Map logical token indices -> physical block ids. blk = t_range // block_size inb = t_range - blk * block_size req_blk_start = tl.load(req_blk_starts_ptr + b) gblk = req_blk_start + blk # Guard against out-of-range block indices (should not happen when block_table # covers the sequence length). gblk_safe = tl.where(t_mask, gblk, 0) bid = tl.load(blk_ids_ptr + gblk_safe, mask=t_mask, other=0) # Source: key cache layout [num_blocks, H, block_size, D] src_base = K_ptr + bid[:, None] * sKb + h * sKh + inb[:, None] * sKt src_ptrs = src_base + d_range[None, :] * sKd # Destination: packed output layout [T, H, D] out_start = tl.load(cu_seqlens_ptr + b) dst_base = out_ptr + (out_start + t_range)[:, None] * so_t + h * so_h dst_ptrs = dst_base + d_range[None, :] * so_d tile = tl.load(src_ptrs, mask=(t_mask[:, None] & d_mask[None, :]), other=0) tl.store(dst_ptrs, tile, mask=(t_mask[:, None] & d_mask[None, :])) @torch.inference_mode() def gather_k_to_packed_triton( key_cache: torch.Tensor, block_table: torch.Tensor, seq_lens: torch.Tensor, cu_seqlens: torch.Tensor, *, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Gather a block-wise KV key cache into a packed [T, H, D] tensor. Expected layouts: - key_cache: [num_blocks, H, block_size, D] - block_table: [B, max_blocks] int32 physical block ids - seq_lens: [B] int32 logical lengths (tokens) to gather - cu_seqlens: [B+1] int32 cumulative offsets into the packed output """ _require_triton() _check_cuda(key_cache, block_table, seq_lens, cu_seqlens) if key_cache.ndim != 4: raise ValueError("key_cache must be a 4D tensor [num_blocks, H, Tb, D].") if block_table.ndim != 2: raise ValueError("block_table must be 2D [B, max_blocks].") if seq_lens.ndim != 1: raise ValueError("seq_lens must be 1D [B].") if cu_seqlens.ndim != 1: raise ValueError("cu_seqlens must be 1D [B+1].") device = key_cache.device B = int(seq_lens.numel()) if B == 0: return torch.empty((0, int(key_cache.shape[1]), int(key_cache.shape[3])), device=device, dtype=key_cache.dtype) H = int(key_cache.shape[1]) block_size = int(key_cache.shape[2]) D = int(key_cache.shape[3]) max_blocks = int(block_table.shape[1]) seq_lens_i32 = seq_lens.to(device=device, dtype=torch.int32) cu_i32 = cu_seqlens.to(device=device, dtype=torch.int32) total_tokens = int(cu_i32[-1].item()) if cu_i32.numel() > 0 else 0 if out is None: out = torch.empty((total_tokens, H, D), device=device, dtype=key_cache.dtype) else: if out.shape != (total_tokens, H, D): raise ValueError( f"out has shape {tuple(out.shape)}, expected {(total_tokens, H, D)}." ) blk_ids = block_table.to(device=device, dtype=torch.int32).reshape(-1) req_starts = (torch.arange(B, device=device, dtype=torch.int32) * max_blocks) sKb, sKh, sKt, sKd = [int(s) for s in key_cache.stride()] so_t, so_h, so_d = [int(s) for s in out.stride()] L_max = int(seq_lens_i32.max().item()) if B > 0 else 0 if total_tokens == 0 or L_max == 0 or D == 0 or H == 0: return out # Use the smallest tile sizes across autotune configs to guarantee coverage # even when the selected config uses smaller blocks. grid = ( B * H, triton.cdiv(L_max, 128), triton.cdiv(D, 64), ) _gather_k_to_packed_kernel[grid]( key_cache, out, blk_ids, req_starts, cu_i32, seq_lens_i32, B, H, max_blocks, block_size, D, sKb, sKh, sKt, sKd, so_t, so_h, so_d, ) return out @triton.autotune( configs=[ triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2), triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2), triton.Config({'BLOCK_T': 512, 'BLOCK_D': 64}, num_warps=8, num_stages=2), triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2), ], key=['K_max', 'Dk'], ) @triton.jit def _front_compact_inplace_fa_k_kernel( K_ptr, blk_ids_ptr, req_blk_starts_ptr, idx_ptr, keep_ptr, B, H, K_max, block_size, Dk, sKb, sKh, sKt, sKd, si_b, si_h, si_k, BLOCK_T: tl.constexpr, BLOCK_D: tl.constexpr, ): pid_bh = tl.program_id(0) pid_d = tl.program_id(1) b = pid_bh // H h = pid_bh % H if b >= B: return d0 = pid_d * BLOCK_D d_range = d0 + tl.arange(0, BLOCK_D) d_mask = d_range < Dk d_safe = tl.where(d_mask, d_range, 0) keep_b = tl.load(keep_ptr + b) if keep_b <= 0: return req_blk_start = tl.load(req_blk_starts_ptr + b) k0 = 0 while k0 < keep_b: k_range = k0 + tl.arange(0, BLOCK_T) k_mask = (k_range < K_max) & (k_range < keep_b) k_safe = tl.where(k_mask, k_range, 0) idx_base = idx_ptr + b * si_b + h * si_h + k_safe * si_k t_src = tl.load(idx_base, mask=k_mask, other=0) # No-op copies (src == dst) can be skipped safely because idx_sorted is # ascending, so we always copy from later/equal positions to earlier. t_dst = k_safe copy_mask = k_mask & (t_src != t_dst) blk_src = t_src // block_size inb_src = t_src % block_size gblk_src = req_blk_start + blk_src bid_src = tl.load(blk_ids_ptr + gblk_src, mask=copy_mask, other=0) blk_dst = t_dst // block_size inb_dst = t_dst % block_size gblk_dst = req_blk_start + blk_dst bid_dst = tl.load(blk_ids_ptr + gblk_dst, mask=copy_mask, other=0) src_base = K_ptr + bid_src[:, None] * sKb + h * sKh + inb_src[:, None] * sKt src_ptrs = src_base + d_safe[None, :] * sKd dst_base = K_ptr + bid_dst[:, None] * sKb + h * sKh + inb_dst[:, None] * sKt dst_ptrs = dst_base + d_safe[None, :] * sKd tile = tl.load(src_ptrs, mask=(copy_mask[:, None] & d_mask[None, :]), other=0) tl.store(dst_ptrs, tile, mask=(copy_mask[:, None] & d_mask[None, :])) k0 += BLOCK_T @triton.autotune( configs=[ triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2), triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2), triton.Config({'BLOCK_T': 512, 'BLOCK_D': 64}, num_warps=8, num_stages=2), triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2), ], key=['K_max', 'Dv'], ) @triton.jit def _front_compact_inplace_fa_v_kernel( V_ptr, blk_ids_ptr, req_blk_starts_ptr, idx_ptr, keep_ptr, B, H, K_max, block_size, Dv, sv_b, sv_h, sv_d, sv_t, si_b, si_h, si_k, BLOCK_T: tl.constexpr, BLOCK_D: tl.constexpr, ): pid_bh = tl.program_id(0) pid_d = tl.program_id(1) b = pid_bh // H h = pid_bh % H if b >= B: return d0 = pid_d * BLOCK_D d_range = d0 + tl.arange(0, BLOCK_D) d_mask = d_range < Dv d_safe = tl.where(d_mask, d_range, 0) keep_b = tl.load(keep_ptr + b) if keep_b <= 0: return req_blk_start = tl.load(req_blk_starts_ptr + b) k0 = 0 while k0 < keep_b: k_range = k0 + tl.arange(0, BLOCK_T) k_mask = (k_range < K_max) & (k_range < keep_b) k_safe = tl.where(k_mask, k_range, 0) idx_base = idx_ptr + b * si_b + h * si_h + k_safe * si_k t_src = tl.load(idx_base, mask=k_mask, other=0) t_dst = k_safe copy_mask = k_mask & (t_src != t_dst) blk_src = t_src // block_size inb_src = t_src % block_size gblk_src = req_blk_start + blk_src bid_src = tl.load(blk_ids_ptr + gblk_src, mask=copy_mask, other=0) blk_dst = t_dst // block_size inb_dst = t_dst % block_size gblk_dst = req_blk_start + blk_dst bid_dst = tl.load(blk_ids_ptr + gblk_dst, mask=copy_mask, other=0) # value layout: [num_blocks, H, Dv, block_size] v_src_base = V_ptr + bid_src[:, None] * sv_b + h * sv_h + d_safe[None, :] * sv_d v_src_ptrs = v_src_base + inb_src[:, None] * sv_t v_dst_base = V_ptr + bid_dst[:, None] * sv_b + h * sv_h + d_safe[None, :] * sv_d v_dst_ptrs = v_dst_base + inb_dst[:, None] * sv_t tile = tl.load(v_src_ptrs, mask=(copy_mask[:, None] & d_mask[None, :]), other=0) tl.store(v_dst_ptrs, tile, mask=(copy_mask[:, None] & d_mask[None, :])) k0 += BLOCK_T @torch.inference_mode() def front_compact_inplace_fa_triton( key_cache: torch.Tensor, value_cache: torch.Tensor, block_table: torch.Tensor, idx_sorted: torch.Tensor, keep: torch.Tensor, ) -> None: """In-place front compaction for FlashAttention KV cache. Moves selected time indices to the front [0..keep[b]) per request for both key_cache and value_cache in-place. Expected layouts: - key_cache: [num_blocks, H, block_size, Dk] - value_cache: [num_blocks, H, Dv, block_size] - block_table: [B, max_blocks] int32 physical block ids - idx_sorted: [B, K] int32 or [B, H, K] int32 (ascending indices) - keep: [B] int32 (<= K), number of kept tokens per request """ _require_triton() _check_cuda(key_cache, value_cache, block_table, idx_sorted, keep) if key_cache.ndim != 4 or value_cache.ndim != 4: raise ValueError("key_cache/value_cache must be 4D tensors.") if block_table.ndim != 2: raise ValueError("block_table must be 2D [B, max_blocks].") if idx_sorted.ndim not in (2, 3): raise ValueError("idx_sorted must be 2D [B,K] or 3D [B,H,K].") if keep.ndim != 1: raise ValueError("keep must be 1D [B].") device = key_cache.device B = int(block_table.shape[0]) if B == 0: return H = int(key_cache.shape[1]) block_size = int(key_cache.shape[2]) Dk = int(key_cache.shape[3]) Dv = int(value_cache.shape[2]) if idx_sorted.ndim == 2: idx_sorted = idx_sorted[:, None, :].expand(-1, H, -1) K_max = int(idx_sorted.shape[2]) if K_max == 0: return blk_ids = block_table.to(device=device, dtype=torch.int32).reshape(-1) max_blocks = int(block_table.shape[1]) req_starts = (torch.arange(B, device=device, dtype=torch.int32) * max_blocks) idx_i32 = idx_sorted.to(device=device, dtype=torch.int32) keep_i32 = keep.to(device=device, dtype=torch.int32) sKb, sKh, sKt, sKd = [int(s) for s in key_cache.stride()] sv_b, sv_h, sv_d, sv_t = [int(s) for s in value_cache.stride()] si_b, si_h, si_k = [int(s) for s in idx_i32.stride()] if Dk > 0: grid_k = ( B * H, triton.cdiv(Dk, 64), ) _front_compact_inplace_fa_k_kernel[grid_k]( key_cache, blk_ids, req_starts, idx_i32, keep_i32, B, H, K_max, block_size, Dk, sKb, sKh, sKt, sKd, si_b, si_h, si_k, ) if Dv > 0: grid_v = ( B * H, triton.cdiv(Dv, 64), ) _front_compact_inplace_fa_v_kernel[grid_v]( value_cache, blk_ids, req_starts, idx_i32, keep_i32, B, H, K_max, block_size, Dv, sv_b, sv_h, sv_d, sv_t, si_b, si_h, si_k, ) def make_fa_cache_view( *, key_cache: torch.Tensor, value_cache: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Return (K_view, V_view) in the canonical FA compaction layout. - K_view: [num_blocks, H, block_size, D] - V_view: [num_blocks, H, D, block_size] """ if key_cache.ndim != 4 or value_cache.ndim != 4: raise ValueError("key_cache/value_cache must be 4D tensors.") # ROCm path (FlashAttention v1): K=[B,H,T,D] and V=[B,H,D,T] if (value_cache.shape[3] == key_cache.shape[2] and value_cache.shape[2] == key_cache.shape[3]): k_view = key_cache v_view = value_cache else: # CUDA path: K=[B,T,H,D] and V=[B,T,H,D] k_view = key_cache.permute(0, 2, 1, 3) v_view = value_cache.permute(0, 2, 3, 1) return k_view, v_view