runner_buffers.py 3.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from typing import Any, Optional

import torch


def _alloc_host_buffer(
    size: int,
    dtype: torch.dtype,
    pin_memory: bool,
) -> tuple[torch.Tensor, Any]:
    """Allocate a zero-filled CPU tensor and a NumPy view of it.

    The array returned by ``Tensor.numpy()`` shares storage with the CPU
    tensor, so writes through either handle are visible through the other.

    Args:
        size: Number of elements.
        dtype: Element dtype of the tensor.
        pin_memory: Whether to allocate the CPU tensor in pinned memory.

    Returns:
        ``(cpu_tensor, numpy_view)``.
    """
    cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=pin_memory)
    return cpu, cpu.numpy()


def _alloc_host_device_buffers(
    size: int,
    dtype: torch.dtype,
    device: torch.device,
    pin_memory: bool,
) -> tuple[torch.Tensor, Any, torch.Tensor]:
    """Allocate a host buffer, its NumPy view, and a matching device buffer.

    Args:
        size: Number of elements in both the host and device tensors.
        dtype: Element dtype shared by both tensors.
        device: Device on which the device-side tensor is allocated.
        pin_memory: Whether to allocate the CPU tensor in pinned memory.

    Returns:
        ``(cpu_tensor, numpy_view, device_tensor)``, all zero-initialized.
    """
    cpu, cpu_np = _alloc_host_buffer(size, dtype, pin_memory)
    return cpu, cpu_np, torch.zeros(size, dtype=dtype, device=device)


def init_kv_compression_runner_buffers(
    *,
    runner: Any,
    max_num_tokens: int,
    max_num_reqs: int,
    device: torch.device,
    pin_memory: bool,
) -> None:
    """Initialize per-runner buffers used by KV compression.

    This helper keeps `gpu_model_runner.py` focused on orchestration while
    preserving the existing attribute-based access patterns.

    For each buffer ``<name>`` this sets the following attributes on
    ``runner``:

    - ``<name>_cpu``: a zero-filled CPU tensor, pinned when ``pin_memory``.
    - ``<name>_np``: a NumPy view sharing the CPU tensor's storage.
    - ``<name>``: a zero-filled tensor on ``device``. This one is skipped
      for ``kv_positions``, which is host-only here.

    It also resets ``runner.kv_compression_prompt_topk_keep_max`` to None.

    Args:
        runner: Object that receives the buffers as attributes. Presumably
            this is the GPU model runner; verify against callers.
        max_num_tokens: Length of the per-token buffers.
        max_num_reqs: Length of the per-request buffers.
        device: Device for the device-side buffers.
        pin_memory: Whether host-side buffers are allocated in pinned memory.
    """
    # KV positions are decoupled from logical positions when KV compression is
    # enabled. Keep a separate buffer to avoid recomputing or overwriting the
    # logical `positions_np` (used for RoPE / token lookup). No device-side
    # copy is allocated for this buffer here.
    runner.kv_positions_cpu, runner.kv_positions_np = _alloc_host_buffer(
        max_num_tokens, torch.int64, pin_memory
    )

    # KV compression metadata buffers (used by the "topk" policy).
    # Per-token: whether this scheduled token must be kept in KV cache.
    (
        runner.kv_compression_must_keep_cpu,
        runner.kv_compression_must_keep_np,
        runner.kv_compression_must_keep,
    ) = _alloc_host_device_buffers(
        max_num_tokens, torch.bool, device, pin_memory
    )

    # Per-request: how many additional prompt tokens to keep among
    # non-protected candidates (budget from env; selection uses scores).
    (
        runner.kv_compression_topk_budget_cpu,
        runner.kv_compression_topk_budget_np,
        runner.kv_compression_topk_budget,
    ) = _alloc_host_device_buffers(
        max_num_reqs, torch.int32, device, pin_memory
    )

    # Chunked-prefill prompt-end KV compression metadata (scheme 3).
    # Per-request: whether this step finishes the prompt and should compute
    # global prompt indices (score/topk) for a one-shot compaction.
    (
        runner.kv_compression_prompt_end_cpu,
        runner.kv_compression_prompt_end_np,
        runner.kv_compression_prompt_end,
    ) = _alloc_host_device_buffers(
        max_num_reqs, torch.bool, device, pin_memory
    )

    # Per-request: prompt length (tokens) and Top-K keep count among prompt
    # candidates (excluding protected prefix/suffix).
    (
        runner.kv_compression_prompt_lens_cpu,
        runner.kv_compression_prompt_lens_np,
        runner.kv_compression_prompt_lens,
    ) = _alloc_host_device_buffers(
        max_num_reqs, torch.int32, device, pin_memory
    )
    (
        runner.kv_compression_prompt_topk_keep_cpu,
        runner.kv_compression_prompt_topk_keep_np,
        runner.kv_compression_prompt_topk_keep,
    ) = _alloc_host_device_buffers(
        max_num_reqs, torch.int32, device, pin_memory
    )

    # Presumably a scalar summary (e.g. max) of the per-request keep counts,
    # filled in elsewhere; verify against callers. Starts unset.
    runner.kv_compression_prompt_topk_keep_max: Optional[int] = None