# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from typing import Optional

import torch

from vllm.v1.kv_compression.topk_select import (_packed_varlen_coords,
                                                _topk_keep_mask_and_local_rank)


def _dst_slots_from_keep_mask_and_local_rank(
    *,
    keep_mask: torch.Tensor,
    local_rank: torch.Tensor,
    seq_lens: torch.Tensor,
    lengths: torch.Tensor,
    req_ids: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
) -> torch.Tensor:
    """Translate a keep mask and per-request ranks into KV cache slots.

    Args:
        keep_mask: [T] bool, True for tokens that are kept.
        local_rank: [T] int64, rank of each token among the kept tokens of
            its request; added to the request's base KV position.
        seq_lens: [B] int32, per-request sequence length for this step.
        lengths: [B] int64, per-request scheduled length for this step.
        req_ids: [T] int64, index of the request each token belongs to.
        block_table: [B, max_blocks] int32, block number per
            (request, block index).
        block_size: Number of slots per KV cache block.

    Returns:
        An int64 tensor of shape [T] holding the destination slot for every
        kept token and -1 for every token where `keep_mask` is False.
    """
    num_tokens = int(keep_mask.numel())
    num_reqs = int(seq_lens.numel())
    all_dropped = torch.full((num_tokens, ),
                             -1,
                             device=keep_mask.device,
                             dtype=torch.int64)
    if num_tokens == 0 or num_reqs == 0:
        return all_dropped

    # KV length of each request before this scheduled segment is written.
    # With KV compression enabled, seq_lens is derived from
    # num_kv_tokens + scheduled_len, hence base = seq_lens - scheduled_len.
    kv_len_before = (seq_lens[:num_reqs].to(torch.long) -
                     lengths.to(torch.long))
    kv_base = kv_len_before.clamp_min(0)

    # Logical destination position of every token inside its request.
    token_base = kv_base.index_select(0, req_ids)  # [T]
    target_pos = token_base + local_rank  # [T]
    block_idx = target_pos // block_size
    in_block_off = target_pos - block_idx * block_size

    # Dropped tokens may point outside the block table; clamp the index so
    # the gather below stays in bounds. Their result is discarded by the
    # final `torch.where` on keep_mask.
    last_block = int(block_table.shape[1]) - 1
    safe_block_idx = torch.clamp(block_idx, 0, last_block).to(torch.long)

    physical_block = block_table[req_ids, safe_block_idx].to(torch.long)
    slot = physical_block * block_size + in_block_off
    return torch.where(keep_mask, slot.to(torch.int64), all_dropped)


def topk_kv_compact_slot_mapping(
    *,
    token_scores: Optional[torch.Tensor],
    must_keep: torch.Tensor,
    topk_budget: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    max_query_len: Optional[int] = None,
    topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
    """Compute, per token, the KV cache slot it is compacted into.

    Args:
        token_scores: [T] float32 scores, or None; forwarded unchanged to
            `_topk_keep_mask_and_local_rank`.
        must_keep: [T] bool; also determines T and the output device.
        topk_budget: [B] int32 per-request budget; determines B.
        query_start_loc: [B+1] cumulative query offsets.
        seq_lens: [B] int32 per-request sequence lengths (only the first B
            entries are used).
        block_table: [B, max_blocks] block numbers.
        block_size: Number of slots per KV cache block.
        max_query_len: Host-side maximum query length, if known. When None
            it is read back from the device tensor instead.
        topk_budget_max: Forwarded unchanged to
            `_topk_keep_mask_and_local_rank`.

    Returns:
        An int64 tensor `dst_slots` of shape [T] where
        - `dst_slots[i] >= 0` means token i is kept and rewritten to that
          KV cache slot, and
        - `dst_slots[i] == -1` means token i is dropped after the step.
    """
    num_tokens = int(must_keep.numel())
    num_reqs = int(topk_budget.numel())
    all_dropped = torch.full((num_tokens, ),
                             -1,
                             device=must_keep.device,
                             dtype=torch.int64)
    if num_tokens == 0 or num_reqs == 0:
        return all_dropped

    starts, _, lengths, req_ids, pos_in_req = _packed_varlen_coords(
        cu_seqlens=query_start_loc,
        total_tokens=num_tokens,
    )
    if lengths.numel() == 0:
        return all_dropped

    # Use the host-side maximum when the caller provides it (piecewise
    # graph); falling back to `.item()` forces a device->host sync.
    if max_query_len is not None:
        longest_query = int(max_query_len)
    else:
        longest_query = int(lengths.max().item())
    if longest_query <= 0:
        return all_dropped

    keep_mask, local_rank, _ = _topk_keep_mask_and_local_rank(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        starts=starts,
        lengths=lengths,
        req_ids=req_ids,
        pos_in_req=pos_in_req,
        max_len=longest_query,
        topk_budget_max=topk_budget_max,
    )

    return _dst_slots_from_keep_mask_and_local_rank(
        keep_mask=keep_mask,
        local_rank=local_rank,
        seq_lens=seq_lens[:num_reqs],
        lengths=lengths,
        req_ids=req_ids,
        block_table=block_table,
        block_size=int(block_size),
    )


def kv_compaction_dst_rewrite_mapping(
    *,
    dst_slots: torch.Tensor,
    src_slots: torch.Tensor,
) -> torch.Tensor:
    """Keep only the destination slots of tokens that actually move.

    Args:
        dst_slots: [T] int64 destination slots; negative means dropped.
        src_slots: [T] int64 slots the tokens currently occupy.

    Returns:
        `dst_slots` with every entry that is negative or equal to the
        matching `src_slots` entry replaced by -1. The cache kernels treat
        -1 as padding and skip it.
    """
    is_kept = dst_slots >= 0
    has_moved = dst_slots != src_slots
    return torch.where(is_kept & has_moved, dst_slots, -1)