from abc import ABC, abstractmethod
import os
from typing import Optional

import torch

from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv


class BaseCompressionMethod(ABC):
    """
    Abstract interface for KV cache compression methods.

    A compression method is implemented as a pair of optional scoring phases
    that run before and after rotary position embedding (RoPE) is applied:

    1. ``pre_rope_scoring`` operates on pre-RoPE Q/K.
    2. ``post_rope_scoring`` operates on post-RoPE Q/K and can either:
       - refine / reweight the pre-RoPE scores, or
       - compute new, potentially position-aware scores.

    Concrete subclasses are expected to implement both static methods and
    return a single tensor of scores (or ``None`` if the phase is a no-op),
    which the caller can then feed into the shared
    “scores → top-k indices → KV extraction” pipeline.
    """

    @staticmethod
    @abstractmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute per-token importance scores from pre-RoPE queries/keys.

        Args:
            :param q: Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k: Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v: Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param context: vllm.kvprune.utils.context.Context object carrying
                additional metadata, such as batch mappings or temporary
                buffers.

        Returns:
            :return Optional[torch.Tensor]: A tensor of scores (e.g. per-token,
                per-head importance values) to be passed to
                ``post_rope_scoring`` or directly into the top-k selection
                step. If this phase is a no-op, implementations should return
                ``None``. Shape ``[total_tokens, HKV]``.
        """
        pass

    @staticmethod
    @abstractmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute or refine importance scores from post-RoPE queries/keys.

        This method is called after rotary embeddings have been applied.
        It can optionally use both the post-RoPE Q/K and any scores produced
        by ``pre_rope_scoring`` to produce final scores used for token
        selection.

        Common patterns include:

        * Using ``pre_rope_scores`` as a base signal and applying a
          position-aware correction.
        * Only computing scores that depend on absolute or relative positions.
        * Simply passing through ``pre_rope_scores`` unchanged.

        Args:
            :param q: Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k: Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v: Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param pre_rope_scores: Optional scores returned by
                ``pre_rope_scoring``. May be ``None`` if the pre-RoPE phase
                returned None.
            :param context: vllm.kvprune.utils.context.Context object carrying
                additional metadata, such as batch mappings or temporary
                buffers.

        Returns:
            :return Optional[torch.Tensor]: Final importance scores to be
                consumed by the compression pipeline (for top-k token
                selection). If this phase is a no-op, implementations may
                return ``pre_rope_scores``. If None is returned, no
                compression will be applied.
        """
        pass


class NoCompression(BaseCompressionMethod):
    """
    Trivial compression method that disables KV cache compression.

    ``pre_rope_scoring`` yields ``None`` and ``post_rope_scoring`` hands back
    whatever it was given, so when the two phases are chained the final score
    is ``None`` (documented on the base interface as "no compression").
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Return ``None``: this method produces no pre-RoPE scores."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Pass ``pre_rope_scores`` through unchanged (it may be ``None``)."""
        return pre_rope_scores


def extract_and_store_top_kv(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    K_TILE: int = 16,
    padding: float = -float("inf"),
) -> None:
    """
    Helper that extracts top-k indices and stores them into the KV cache
    (so both steps can be executed in a single stream).

    The function:

    1. reads the padding strategy from the ``VLLM_KVPRUNE_PADDING_MODE``
       environment variable (default ``"per_head"``),
    2. clamps ``num_tokens_to_retain`` so that no batch row asks for more
       than ``seq_len * H`` (token, head) pairs,
    3. builds the candidate index table with ``scores_to_retain_indices``,
    4. forwards everything to ``prefill_store_topk_kv``.

    Args:
        :param scores: Score tensor forwarded to ``scores_to_retain_indices``.
        :param cu_seqlens_k: Cumulative key sequence lengths; consecutive
            differences give the per-row sequence length.
        :param max_k_len: Forwarded to ``scores_to_retain_indices``.
        :param top_k: Forwarded to ``scores_to_retain_indices``.
        :param H: Number of heads used to flatten (token, head) pairs.
        :param new_keys: Keys handed to the store kernel.
        :param new_vals: Values handed to the store kernel.
        :param num_tokens_to_retain: Per-row number of (token, head) pairs to
            keep. Must not be ``None``.
        :param page_table: Forwarded to ``prefill_store_topk_kv``.
        :param batch_mapping: Forwarded to ``prefill_store_topk_kv``.
        :param bh_lens: Forwarded to ``prefill_store_topk_kv``.
        :param k_cache: Forwarded to ``prefill_store_topk_kv``.
        :param v_cache: Forwarded to ``prefill_store_topk_kv``.
        :param PAGE_SIZE: KV cache page size; also used to size the padding
            candidates.
        :param PAD_TO_PAGE_SIZE: Forwarded to ``prefill_store_topk_kv``.
        :param K_TILE: Forwarded to ``prefill_store_topk_kv``.
        :param padding: Forwarded to ``scores_to_retain_indices``.

    Raises:
        AssertionError: if ``num_tokens_to_retain`` is ``None``.
        ValueError: if the environment selects an unsupported padding mode
            (raised by ``scores_to_retain_indices``).
    """
    assert num_tokens_to_retain is not None, "num_tokens_to_retain must be set"

    # per_head: per-head highest-scoring remaining tokens for page padding.
    # global_scan: legacy global ranking order, padded by scanning forward
    # in-kernel.
    padding_mode = (
        os.environ.get("VLLM_KVPRUNE_PADDING_MODE", "per_head").strip().lower()
    )

    # A row can never retain more pairs than it actually has (seq_len * H).
    max_pairs_per_batch = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).to(
        device=num_tokens_to_retain.device, dtype=num_tokens_to_retain.dtype
    ) * H
    num_tokens_to_retain = torch.minimum(num_tokens_to_retain, max_pairs_per_batch)

    indices_topk, candidate_counts = scores_to_retain_indices(
        scores,
        cu_seqlens_k=cu_seqlens_k,
        max_k_len=max_k_len,
        top_k=top_k,
        H=H,
        num_tokens_to_retain=num_tokens_to_retain,
        page_size=PAGE_SIZE,
        padding_mode=padding_mode,
        padding=padding,
    )

    prefill_store_topk_kv(
        new_keys=new_keys,
        new_vals=new_vals,
        indices_topk=indices_topk,
        candidate_counts=candidate_counts,
        num_tokens_to_retain=num_tokens_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        k_cache=k_cache,
        v_cache=v_cache,
        cu_seqlens_k=cu_seqlens_k,
        PAGE_SIZE=PAGE_SIZE,
        PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
        K_TILE=K_TILE,
    )


def scores_to_retain_indices(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    num_tokens_to_retain: torch.Tensor,
    page_size: int,
    padding_mode: str = "per_head",
    padding: float = -float("inf"),
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build candidate token-head indices for compression writes.

    For each batch element, this helper returns:

    1. a prefix of the true global top-k ``(token, head)`` pairs, and
    2. a suffix of additional padding candidates according to
       ``padding_mode``:

       - ``per_head``: choose each head's highest-scoring remaining tokens.
       - ``global_scan``: keep the legacy global ranking order and let the
         store kernel scan forward until it finds enough entries for that
         head.

    The page-alignment requirement comes from the paged KV cache, but the
    padding candidates themselves do not need to be discovered inside the
    Triton store kernel. Choosing them here avoids the older "scan the global
    candidate list until you stumble across enough entries for this head"
    behavior, which could distort the retained set even though the
    page-table / reclaim logic only cares about the final per-head counts.

    Args:
        :param scores: Tensor of shape ``[N_total, HKV]`` containing scores
            for each (token, head) pair in packed varlen format.
        :param cu_seqlens_k: Tensor of shape ``[B + 1]`` (int32) with
            cumulative key sequence lengths for each batch element. The total
            number of tokens satisfies ``N_total = cu_seqlens_k[-1]``.
        :param max_k_len: Maximum key sequence length across the batch.
            Kept for API compatibility; no longer used.
        :param top_k: Kept for API compatibility with the caller; not used.
            The retained prefix is determined by ``num_tokens_to_retain``;
            the tail is built from per-head padding needs.
        :param H: Number of key heads; must match ``scores.shape[1]``.
        :param num_tokens_to_retain: 1-D tensor with the true number of
            token-head pairs to keep for each batch element before page
            padding.
        :param page_size: Page size of the KV cache. Determines how many
            extra candidates are needed per head to reach page alignment.
        :param padding_mode: ``per_head`` for per-head optimal padding
            candidates, or ``global_scan`` for the legacy "scan the global
            ranking" behavior.
        :param padding: Kept for backward compatibility; no longer used.

    Returns:
        A tuple ``(indices, counts)`` where:

        - ``indices`` is ``[B, MAX_SEL]`` int64, containing global flattened
          ``token * H + head`` indices. Unused trailing slots are zero. When
          no row has any candidate the shape is ``[B, 1]``, all zeros.
        - ``counts`` is ``[B]`` int32, the number of valid candidates for each
          batch row inside ``indices``.

    Raises:
        ValueError: if ``padding_mode`` is neither ``per_head`` nor
            ``global_scan``.
    """
    del max_k_len, top_k, padding

    # Reject bad configuration before doing any work.
    if padding_mode not in ("per_head", "global_scan"):
        raise ValueError(
            "Unsupported VLLM_KVPRUNE_PADDING_MODE. "
            f"Expected 'per_head' or 'global_scan', got {padding_mode!r}."
        )

    B, device = cu_seqlens_k.numel() - 1, scores.device

    # Pull the small control tensors to the host once. Reading them with
    # `.item()` inside the loop would force a device sync on every access.
    seq_bounds = cu_seqlens_k.tolist()
    retain_counts = num_tokens_to_retain.tolist()

    row_indices: list[torch.Tensor] = []
    for b in range(B):
        s, e = int(seq_bounds[b]), int(seq_bounds[b + 1])
        seq_len = e - s
        total_pairs = seq_len * H
        keep = min(int(retain_counts[b]), total_pairs)
        if total_pairs == 0 or keep == 0:
            row_indices.append(torch.empty(0, dtype=torch.int64, device=device))
            continue

        seq_scores = scores[s:e, :]  # [L, H]
        flat_scores = seq_scores.reshape(-1)

        if padding_mode == "global_scan":
            # Legacy mode: hand the full ranking to the store kernel.
            row = torch.argsort(flat_scores, dim=0, descending=True)
        else:
            row = _per_head_padded_row(
                seq_scores, flat_scores, keep, seq_len, H, page_size
            )

        # Local flattened indices -> global flattened indices.
        row_indices.append(row + s * H)

    # Build the count tensor in one shot instead of one scalar write per row.
    row_lengths = [int(row.numel()) for row in row_indices]
    candidate_counts = torch.tensor(row_lengths, dtype=torch.int32, device=device)

    max_sel = max(row_lengths, default=0)
    if max_sel == 0:
        # Keep a non-empty second dimension even when nothing was selected.
        return (
            torch.zeros((B, 1), dtype=torch.int64, device=device),
            candidate_counts,
        )

    indices = torch.zeros((B, max_sel), dtype=torch.int64, device=device)
    for b, row in enumerate(row_indices):
        if row.numel():
            indices[b, : row.numel()] = row
    return indices, candidate_counts


def _per_head_padded_row(
    seq_scores: torch.Tensor,
    flat_scores: torch.Tensor,
    keep: int,
    seq_len: int,
    H: int,
    page_size: int,
) -> torch.Tensor:
    """
    Return one sequence's candidates: global top-``keep`` plus per-head padding.

    The result holds sequence-local flattened ``token * H + head`` indices:
    first the ``keep`` best pairs over all heads (descending score), then, for
    each head in order, the best not-yet-selected tokens needed to bring that
    head's count up to a multiple of ``page_size`` (capped by how many tokens
    the head has left).
    """
    device = flat_scores.device
    prefix = torch.topk(flat_scores, k=keep, dim=0, largest=True, sorted=True).indices

    # Boolean [L, H] map of pairs already taken by the global prefix.
    selected_flat = torch.zeros(seq_len * H, dtype=torch.bool, device=device)
    selected_flat[prefix] = True
    selected_mask = selected_flat.view(seq_len, H)

    # Flattened index is token * H + head, so `% H` recovers the head.
    head_counts = torch.bincount(prefix % H, minlength=H)
    need_per_head = (page_size - (head_counts % page_size)) % page_size
    # A head cannot be padded with more tokens than it has left.
    need_per_head = torch.minimum(need_per_head, seq_len - head_counts)

    tails: list[torch.Tensor] = []
    # Single host transfer for all heads rather than one `.item()` per head.
    for h, need in enumerate(need_per_head.tolist()):
        if need <= 0:
            continue
        # Rank only the unselected tokens. Masking selected ones with -inf
        # and ranking everything can re-pick a selected token when real
        # scores are -inf too, which would duplicate a (token, head) pair.
        free_tok = torch.nonzero(~selected_mask[:, h], as_tuple=True)[0]
        best = torch.topk(
            seq_scores[free_tok, h], k=need, dim=0, largest=True, sorted=True
        ).indices
        tails.append(free_tok[best] * H + h)

    if tails:
        return torch.cat([prefix, *tails], dim=0)
    return prefix