# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from typing import Optional

import numpy as np

import vllm.envs as envs
from vllm.v1.kv_compression.budget import (compute_prompt_topk_keep_total,
                                           compute_topk_budget_step)


def prepare_kv_compression_for_step(
    *,
    num_reqs: int,
    total_num_scheduled_tokens: int,
    num_scheduled_tokens: np.ndarray,  # [B] int32
    cu_num_tokens: np.ndarray,  # [B] int64/int32 cumulative scheduled tokens
    req_indices: np.ndarray,  # [T] int64, request index per token
    arange: np.ndarray,  # [T] int64, position-within-request per token
    num_computed_tokens_cpu: np.ndarray,  # [max_reqs] int32/int64
    num_prompt_tokens: np.ndarray,  # [max_reqs] int32/int64
    num_kv_tokens_cpu: np.ndarray,  # [max_reqs] int32/int64
    kv_positions_np: np.ndarray,  # [T] int64 (out)
    must_keep_np: np.ndarray,  # [T] bool (out; scheme 1/2 only)
    topk_budget_np: np.ndarray,  # [B] int32 (out; scheme 1/2 only)
    prompt_end_np: np.ndarray,  # [B] bool (out; scheme 3 only)
    prompt_lens_np: np.ndarray,  # [B] int32 (out; scheme 3 only)
    prompt_topk_keep_np: np.ndarray,  # [B] int32 (out; scheme 3 only)
    chunked_prefill_enabled: bool,
) -> tuple[bool, Optional[int]]:
    """Prepare KV compression metadata for a single model step (CPU-side).

    Fills:
      - `kv_positions_np`: per-token KV write positions (decoupled from
        logical RoPE positions), computed as
        `num_kv_tokens_cpu[req_indices] + arange`.
      - Scheme 3 (chunked prefill): `prompt_end/prompt_lens/prompt_topk_keep`.
        An entry is set only for a request whose scheduled segment contains
        the end of its prompt; all other entries are reset to False/0.
      - Scheme 1/2 (non-chunked): `must_keep/topk_budget`. Decode tokens, the
        protected prompt prefix/suffix and (optionally) the last prompt token
        are marked as must-keep.

    The output buffers of the active scheme are fully reset before being
    written. Buffers belonging to the other scheme are left untouched.
    Protection sizes, ratio and budget are read from the
    `VLLM_KV_COMPRESSION_*` settings in `vllm.envs` on every call.

    Returns:
      (needs_compaction, prompt_topk_keep_max)

      - Nothing scheduled (`total_num_scheduled_tokens <= 0` or
        `num_reqs <= 0`): `(False, None)`; no buffer is modified.
      - Scheme 3: `(False, max(prompt_topk_keep_np[:num_reqs]))`.
      - Scheme 1/2: `(needs_compaction, None)`, where `needs_compaction` is
        False only if every scheduled token is must-keep and no request has a
        positive Top-K budget.
    """
    if total_num_scheduled_tokens <= 0 or num_reqs <= 0:
        return False, None

    # KV positions (where scheduled tokens are written before optional
    # compaction).
    np.add(num_kv_tokens_cpu[req_indices], arange, out=kv_positions_np)

    prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
    prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
    protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
    protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
    keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN

    if chunked_prefill_enabled:
        # Scheme 3: with chunked prefill, defer compaction until after the
        # full prompt is ingested. Otherwise, the next prefill chunk would
        # attend to a truncated history and quality can collapse.
        prompt_end_np.fill(False)
        prompt_lens_np.fill(0)
        prompt_topk_keep_np.fill(0)
        for req_idx in range(num_reqs):
            qlen = int(num_scheduled_tokens[req_idx])
            if qlen <= 0:
                continue
            base_pos = int(num_computed_tokens_cpu[req_idx])
            prompt_len = int(num_prompt_tokens[req_idx])
            end_pos = base_pos + qlen
            # True only for the step whose segment crosses (or exactly
            # reaches) the end of the prompt.
            ends_prompt = (base_pos < prompt_len) and (end_pos >= prompt_len)
            if not ends_prompt:
                continue
            prompt_end_np[req_idx] = True
            prompt_lens_np[req_idx] = prompt_len
            prompt_topk_keep_np[req_idx] = compute_prompt_topk_keep_total(
                prompt_len=prompt_len,
                protected_prefix=protected_prefix,
                protected_suffix=protected_suffix,
                keep_last_token=keep_last,
                prompt_ratio=prompt_ratio,
                prompt_budget=prompt_budget,
            )
        prompt_topk_keep_max = int(prompt_topk_keep_np[:num_reqs].max())
        return False, prompt_topk_keep_max

    # Scheme 1/2: per-step compaction within the scheduled segment.
    must_keep_np.fill(False)
    topk_budget_np.fill(0)
    for req_idx in range(num_reqs):
        qlen = int(num_scheduled_tokens[req_idx])
        if qlen <= 0:
            continue
        # Token span of this request inside the flattened [T] arrays.
        start = 0 if req_idx == 0 else int(cu_num_tokens[req_idx - 1])
        end = int(cu_num_tokens[req_idx])
        assert end - start == qlen

        base_pos = int(num_computed_tokens_cpu[req_idx])
        prompt_len = int(num_prompt_tokens[req_idx])
        end_pos = base_pos + qlen

        pos_in_req = arange[start:end].astype(np.int64, copy=False)
        pos = base_pos + pos_in_req
        prompt_mask = pos < prompt_len

        # Decode tokens are always kept.
        must_keep = ~prompt_mask
        if np.any(prompt_mask):
            suffix_start = max(prompt_len - protected_suffix, 0)
            must_keep |= prompt_mask & (pos < protected_prefix)
            must_keep |= prompt_mask & (pos >= suffix_start)
            if keep_last:
                last = prompt_len - 1
                if base_pos <= last < end_pos:
                    # Index relative to the start of this request's segment.
                    must_keep[last - base_pos] = True

            topk_budget_np[req_idx] = compute_topk_budget_step(
                prompt_len=prompt_len,
                start_pos=base_pos,
                end_pos=end_pos,
                protected_prefix=protected_prefix,
                protected_suffix=protected_suffix,
                keep_last_token=keep_last,
                prompt_ratio=prompt_ratio,
                prompt_budget=prompt_budget,
            )

        must_keep_np[start:end] = must_keep

    # Decode-only fast path: if all scheduled tokens are unconditionally kept
    # and there is no Top-K budget, KV compaction is a no-op and can be
    # skipped.
    # Only the active region is inspected (mirroring the `[:num_reqs]` slice
    # used in the chunked-prefill branch): entries past the scheduled tokens
    # were just reset to False and would otherwise make `.all()` fail
    # whenever the caller hands in a buffer larger than this step needs.
    scheduled_must_keep = must_keep_np[:total_num_scheduled_tokens]
    scheduled_budget = topk_budget_np[:num_reqs]
    needs_compaction = ((not scheduled_must_keep.all())
                        or (scheduled_budget > 0).any())
    return bool(needs_compaction), None