# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional

import torch

import vllm.envs as envs


@dataclass
class KVCompressionAttentionMetadata:
    """Per-batch KV compression metadata consumed by attention backends.

    Every field defaults to ``None``; an instance with all fields unset is
    what the builder returns when it has nothing to report for the step.
    """

    # Populated on the per-step compaction path (scheme 1/2).
    must_keep: Optional[torch.Tensor] = None
    topk_budget: Optional[torch.Tensor] = None
    topk_budget_max: Optional[int] = None

    # Populated on the prompt-end one-shot scoring path (scheme 3).
    prompt_end: Optional[torch.Tensor] = None
    prompt_lens: Optional[torch.Tensor] = None
    prompt_topk_keep: Optional[torch.Tensor] = None
    prompt_topk_keep_max: Optional[int] = None


def build_kv_compression_attn_metadata(
    *,
    runner: Any,
    num_reqs: int,
    num_actual_tokens: int,
) -> KVCompressionAttentionMetadata:
    """Assemble the KV compression metadata for a single attention step.

    Centralizes the choice between per-step compaction (scheme 1/2) and
    prompt-end one-shot scoring (scheme 3) so attention backends do not have
    to duplicate it.

    Args:
        runner: Object exposing the ``kv_compression_*`` buffers and,
            optionally, ``kv_compression_needs_compaction`` and
            ``scheduler_config``.
        num_reqs: Number of leading entries taken from the per-request
            buffers.
        num_actual_tokens: Number of leading entries taken from the
            per-token ``kv_compression_must_keep`` buffer.

    Returns:
        A ``KVCompressionAttentionMetadata`` with the scheme 1/2 fields set,
        the scheme 3 fields set, or no fields set when KV compression is
        disabled or there is nothing to do this step.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return KVCompressionAttentionMetadata()

    # Scheme 1/2: compaction destinations are computed on every step.
    if getattr(runner, "kv_compression_needs_compaction", False):
        must_keep = runner.kv_compression_must_keep[:num_actual_tokens]
        topk_budget = runner.kv_compression_topk_budget[:num_reqs]
        # Read the maximum from the CPU staging (`_np`) buffer so that no
        # device->host sync is triggered.
        budget_max = 0
        if num_reqs > 0:
            budget_max = int(
                runner.kv_compression_topk_budget_np[:num_reqs].max())
        return KVCompressionAttentionMetadata(
            must_keep=must_keep,
            topk_budget=topk_budget,
            topk_budget_max=budget_max,
        )

    # Scheme 3: global prompt indices are computed only on the last prefill
    # chunk; the actual cache compaction happens before the first decode
    # step.
    sched_cfg = getattr(runner, "scheduler_config", None)
    chunked_prefill = sched_cfg is not None and getattr(
        sched_cfg, "enable_chunked_prefill", False)
    if not chunked_prefill or num_reqs <= 0:
        return KVCompressionAttentionMetadata()

    # Nothing to score unless at least one request is flagged as reaching
    # its prompt end in this step.
    if not runner.kv_compression_prompt_end_np[:num_reqs].any():
        return KVCompressionAttentionMetadata()

    prompt_end = runner.kv_compression_prompt_end[:num_reqs]
    prompt_lens = runner.kv_compression_prompt_lens[:num_reqs]
    prompt_topk_keep = runner.kv_compression_prompt_topk_keep[:num_reqs]
    # A missing or falsy (e.g. None) attribute is coerced to 0.
    keep_max = int(
        getattr(runner, "kv_compression_prompt_topk_keep_max", 0) or 0)
    return KVCompressionAttentionMetadata(
        prompt_end=prompt_end,
        prompt_lens=prompt_lens,
        prompt_topk_keep=prompt_topk_keep,
        prompt_topk_keep_max=keep_max,
    )