# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations from typing import Any, Optional import torch def init_kv_compression_runner_buffers( *, runner: Any, max_num_tokens: int, max_num_reqs: int, device: torch.device, pin_memory: bool, ) -> None: """Initialize per-runner buffers used by KV compression. This helper keeps `gpu_model_runner.py` focused on orchestration while preserving the existing attribute-based access patterns. """ # KV positions are decoupled from logical positions when KV compression is # enabled. Keep a separate buffer to avoid recomputing or overwriting the # logical `positions_np` (used for RoPE / token lookup). runner.kv_positions_cpu = torch.zeros( max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=pin_memory, ) runner.kv_positions_np = runner.kv_positions_cpu.numpy() # KV compression metadata buffers (used by the "topk" policy). # Per-token: whether this scheduled token must be kept in KV cache. runner.kv_compression_must_keep_cpu = torch.zeros( max_num_tokens, dtype=torch.bool, device="cpu", pin_memory=pin_memory, ) runner.kv_compression_must_keep_np = runner.kv_compression_must_keep_cpu.numpy() runner.kv_compression_must_keep = torch.zeros( max_num_tokens, dtype=torch.bool, device=device, ) # Per-request: how many additional prompt tokens to keep among # non-protected candidates (budget from env; selection uses scores). runner.kv_compression_topk_budget_cpu = torch.zeros( max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory, ) runner.kv_compression_topk_budget_np = runner.kv_compression_topk_budget_cpu.numpy() runner.kv_compression_topk_budget = torch.zeros( max_num_reqs, dtype=torch.int32, device=device, ) # Chunked-prefill prompt-end KV compression metadata (scheme 3). # Per-request: whether this step finishes the prompt and should compute # global prompt indices (score/topk) for a one-shot compaction. runner.kv_compression_prompt_end_cpu = torch.zeros( max_num_reqs, dtype=torch.bool, device="cpu", pin_memory=pin_memory, ) runner.kv_compression_prompt_end_np = runner.kv_compression_prompt_end_cpu.numpy() runner.kv_compression_prompt_end = torch.zeros( max_num_reqs, dtype=torch.bool, device=device, ) # Per-request: prompt length (tokens) and Top-K keep count among prompt # candidates (excluding protected prefix/suffix). runner.kv_compression_prompt_lens_cpu = torch.zeros( max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory, ) runner.kv_compression_prompt_lens_np = runner.kv_compression_prompt_lens_cpu.numpy() runner.kv_compression_prompt_lens = torch.zeros( max_num_reqs, dtype=torch.int32, device=device, ) runner.kv_compression_prompt_topk_keep_cpu = torch.zeros( max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory, ) runner.kv_compression_prompt_topk_keep_np = runner.kv_compression_prompt_topk_keep_cpu.numpy() runner.kv_compression_prompt_topk_keep = torch.zeros( max_num_reqs, dtype=torch.int32, device=device, ) runner.kv_compression_prompt_topk_keep_max = None # type: Optional[int]