metadata.py 2.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional

import torch

import vllm.envs as envs


@dataclass
class KVCompressionAttentionMetadata:
    """Per-batch KV compression metadata consumed by attention backends.

    Every field defaults to ``None``, so a default-constructed instance
    carries no compression inputs. ``build_kv_compression_attn_metadata``
    populates either the per-step compaction group or the prompt-end group
    in a given call, never both.
    """

    # --- Per-step compaction inputs (scheme 1/2 in the builder) ---
    # First ``num_actual_tokens`` entries of ``runner.kv_compression_must_keep``.
    must_keep: Optional[torch.Tensor] = None
    # First ``num_reqs`` entries of ``runner.kv_compression_topk_budget``.
    topk_budget: Optional[torch.Tensor] = None
    # Maximum over the first ``num_reqs`` entries of the runner's
    # ``kv_compression_topk_budget_np`` buffer; 0 when ``num_reqs`` is not
    # positive.
    topk_budget_max: Optional[int] = None

    # --- Prompt-end one-shot scoring inputs (scheme 3 in the builder) ---
    # First ``num_reqs`` entries of ``runner.kv_compression_prompt_end``.
    prompt_end: Optional[torch.Tensor] = None
    # First ``num_reqs`` entries of ``runner.kv_compression_prompt_lens``.
    prompt_lens: Optional[torch.Tensor] = None
    # First ``num_reqs`` entries of
    # ``runner.kv_compression_prompt_topk_keep``.
    prompt_topk_keep: Optional[torch.Tensor] = None
    # ``runner.kv_compression_prompt_topk_keep_max`` coerced to ``int``;
    # a missing or falsy runner attribute becomes 0.
    prompt_topk_keep_max: Optional[int] = None


def build_kv_compression_attn_metadata(
    *,
    runner: Any,
    num_reqs: int,
    num_actual_tokens: int,
) -> KVCompressionAttentionMetadata:
    """Assemble the KV compression metadata for a single attention step.

    Keeps the mode selection in one place so attention backends do not have
    to repeat it. Exactly one of the following is returned:

    * an empty metadata object, when KV compression is disabled or the
      prompt-end path has nothing to do for this batch;
    * per-step compaction inputs (scheme 1/2), when the runner reports
      ``kv_compression_needs_compaction``;
    * prompt-end one-shot scoring inputs (scheme 3), when chunked prefill is
      enabled and at least one of the first ``num_reqs`` entries of
      ``runner.kv_compression_prompt_end_np`` is truthy.

    Args:
        runner: Object exposing the ``kv_compression_*`` buffers and,
            optionally, ``scheduler_config``.
        num_reqs: Number of leading per-request entries to slice.
        num_actual_tokens: Number of leading per-token entries to slice.

    Returns:
        A ``KVCompressionAttentionMetadata`` with the relevant group of
        fields filled in and all other fields left as ``None``.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return KVCompressionAttentionMetadata()

    # Scheme 1/2: compaction destinations are computed on every step.
    if getattr(runner, "kv_compression_needs_compaction", False):
        keep_mask = runner.kv_compression_must_keep[:num_actual_tokens]
        budgets = runner.kv_compression_topk_budget[:num_reqs]
        # Take the max from the CPU staging buffer rather than the device
        # tensor so no device->host sync is needed. An empty slice has no
        # max, hence the explicit zero.
        largest_budget = (
            int(runner.kv_compression_topk_budget_np[:num_reqs].max())
            if num_reqs > 0 else 0)
        return KVCompressionAttentionMetadata(
            must_keep=keep_mask,
            topk_budget=budgets,
            topk_budget_max=largest_budget,
        )

    # Scheme 3: global prompt indices are computed only on the last prefill
    # chunk; the actual cache compaction happens before the first decode
    # step. This path is taken only when chunked prefill is enabled.
    sched_cfg = getattr(runner, "scheduler_config", None)
    chunked_prefill = (sched_cfg is not None and getattr(
        sched_cfg, "chunked_prefill_enabled", False))
    if not chunked_prefill:
        return KVCompressionAttentionMetadata()

    # Nothing to score unless some request in the batch is flagged.
    if (num_reqs <= 0
            or not runner.kv_compression_prompt_end_np[:num_reqs].any()):
        return KVCompressionAttentionMetadata()

    return KVCompressionAttentionMetadata(
        prompt_end=runner.kv_compression_prompt_end[:num_reqs],
        prompt_lens=runner.kv_compression_prompt_lens[:num_reqs],
        prompt_topk_keep=runner.kv_compression_prompt_topk_keep[:num_reqs],
        prompt_topk_keep_max=int(
            getattr(runner, "kv_compression_prompt_topk_keep_max", 0) or 0),
    )