# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations import torch import vllm.envs as envs from vllm.v1.kv_compression.slot_mapping import topk_kv_compact_slot_mapping from vllm.v1.kv_compression.snapkv_score import snapkv_like_token_scores def snapkv_window_for_topk_budget( *, topk_budget: torch.Tensor, # [B] int32 window: int, ) -> torch.Tensor: """Build per-request SnapKV window sizes for mixed batches. Requests with a zero Top-K budget do not need token scores; setting their window to 0 lets the Triton scoring kernel early-return. """ return torch.where( topk_budget > 0, torch.full_like(topk_budget, int(window)), torch.zeros_like(topk_budget), ) def compute_compact_dst_slots_for_step( *, query: torch.Tensor, # [T, Hq, D] for this step key: torch.Tensor, # [T, Hkv, D] for this step query_start_loc: torch.Tensor, # [B+1] seq_lens: torch.Tensor, # [B] int32 block_table: torch.Tensor, # [B, max_blocks] block_size: int, must_keep: torch.Tensor, # [T] bool topk_budget: torch.Tensor, # [B] int32 topk_budget_max: int, max_query_len: int, sm_scale: float, ) -> torch.Tensor: """Compute per-token KV compaction destinations for one step.""" token_scores = None if int(topk_budget_max) > 0: w = snapkv_window_for_topk_budget( topk_budget=topk_budget, window=int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW), ) token_scores = snapkv_like_token_scores( query=query, key=key, query_start_loc=query_start_loc, window=w, sm_scale=float(sm_scale), ) return topk_kv_compact_slot_mapping( token_scores=token_scores, must_keep=must_keep, topk_budget=topk_budget, query_start_loc=query_start_loc, seq_lens=seq_lens, block_table=block_table, block_size=int(block_size), max_query_len=int(max_query_len), topk_budget_max=int(topk_budget_max), )