runner_prepare.py 5.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from typing import Optional

import numpy as np

import vllm.envs as envs
from vllm.v1.kv_compression.budget import (compute_prompt_topk_keep_total,
                                           compute_topk_budget_step)


def prepare_kv_compression_for_step(
    *,
    num_reqs: int,
    total_num_scheduled_tokens: int,
    num_scheduled_tokens: np.ndarray,  # [B] int32
    cu_num_tokens: np.ndarray,  # [B] int64/int32 cumulative scheduled tokens
    req_indices: np.ndarray,  # [T] int64, request index per token
    arange: np.ndarray,  # [T] int64, position-within-request per token
    num_computed_tokens_cpu: np.ndarray,  # [max_reqs] int32/int64
    num_prompt_tokens: np.ndarray,  # [max_reqs] int32/int64
    num_kv_tokens_cpu: np.ndarray,  # [max_reqs] int32/int64
    kv_positions_np: np.ndarray,  # [T] int64 (out)
    must_keep_np: np.ndarray,  # [T] bool (out; scheme 1/2 only)
    topk_budget_np: np.ndarray,  # [B] int32 (out; scheme 1/2 only)
    prompt_end_np: np.ndarray,  # [B] bool (out; scheme 3 only)
    prompt_lens_np: np.ndarray,  # [B] int32 (out; scheme 3 only)
    prompt_topk_keep_np: np.ndarray,  # [B] int32 (out; scheme 3 only)
    chunked_prefill_enabled: bool,
) -> tuple[bool, Optional[int]]:
    """Build the CPU-side KV compression inputs for one model step.

    `kv_positions_np` receives, for every scheduled token, the
    `num_kv_tokens_cpu` entry of its request plus its `arange` offset. This
    is the KV write position and is decoupled from the logical RoPE position.

    The remaining outputs depend on the mode:

    - Chunked prefill (scheme 3): `prompt_end_np`, `prompt_lens_np` and
      `prompt_topk_keep_np` are reset, then set for each request whose
      scheduled tokens start inside the prompt and reach its end.
    - Otherwise (scheme 1/2): `must_keep_np` and `topk_budget_np` are reset,
      then set per request for the scheduled segment.

    No output buffer is touched when there are no requests or no scheduled
    tokens.

    Returns:
      (needs_compaction, prompt_topk_keep_max). With chunked prefill this is
      `(False, max of prompt_topk_keep_np[:num_reqs])`. Otherwise the second
      element is `None` and the first is True when some scheduled token is
      not force-kept or some request has a positive Top-K budget.
    """
    if num_reqs <= 0 or total_num_scheduled_tokens <= 0:
        return False, None

    # Slot each scheduled token is written to, prior to any compaction.
    np.add(num_kv_tokens_cpu[req_indices], arange, out=kv_positions_np)

    # Compression knobs; both budget helpers take exactly these keywords.
    budget_kwargs = {
        "prompt_ratio": envs.VLLM_KV_COMPRESSION_PROMPT_RATIO,
        "prompt_budget": envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET,
        "protected_prefix": envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX,
        "protected_suffix": envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX,
        "keep_last_token": envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN,
    }

    if chunked_prefill_enabled:
        # Scheme 3: compaction waits until the whole prompt has been
        # ingested. Compacting earlier would make the next prefill chunk
        # attend to a truncated history, which can collapse quality.
        for out_buf, cleared in (
            (prompt_end_np, False),
            (prompt_lens_np, 0),
            (prompt_topk_keep_np, 0),
        ):
            out_buf.fill(cleared)

        for i in range(num_reqs):
            n_sched = int(num_scheduled_tokens[i])
            if n_sched <= 0:
                continue
            first_pos = int(num_computed_tokens_cpu[i])
            n_prompt = int(num_prompt_tokens[i])
            # Only a chunk that begins inside the prompt and covers its
            # final token marks the prompt end.
            if not (first_pos < n_prompt <= first_pos + n_sched):
                continue

            prompt_end_np[i] = True
            prompt_lens_np[i] = n_prompt
            prompt_topk_keep_np[i] = compute_prompt_topk_keep_total(
                prompt_len=n_prompt, **budget_kwargs)

        return False, int(prompt_topk_keep_np[:num_reqs].max())

    # Scheme 1/2: compaction happens per step, inside the scheduled segment.
    prefix_len = budget_kwargs["protected_prefix"]
    suffix_len = budget_kwargs["protected_suffix"]
    pin_last_prompt_token = budget_kwargs["keep_last_token"]

    must_keep_np.fill(False)
    topk_budget_np.fill(0)

    for i in range(num_reqs):
        n_sched = int(num_scheduled_tokens[i])
        if n_sched <= 0:
            continue
        # Token range [seg_start, seg_end) of this request in the flat batch.
        seg_start = int(cu_num_tokens[i - 1]) if i > 0 else 0
        seg_end = int(cu_num_tokens[i])
        assert seg_end - seg_start == n_sched

        first_pos = int(num_computed_tokens_cpu[i])
        n_prompt = int(num_prompt_tokens[i])
        stop_pos = first_pos + n_sched
        offsets = arange[seg_start:seg_end].astype(np.int64, copy=False)
        token_pos = first_pos + offsets

        in_prompt = token_pos < n_prompt
        # Anything past the prompt (decode tokens) is kept unconditionally.
        keep = ~in_prompt

        if in_prompt.any():
            suffix_from = max(n_prompt - suffix_len, 0)
            # Prompt tokens in the protected prefix or suffix are pinned.
            keep |= in_prompt & ((token_pos < prefix_len)
                                 | (token_pos >= suffix_from))
            last_prompt_pos = n_prompt - 1
            if (pin_last_prompt_token
                    and first_pos <= last_prompt_pos < stop_pos):
                keep[last_prompt_pos - first_pos] = True

            topk_budget_np[i] = compute_topk_budget_step(
                prompt_len=n_prompt,
                start_pos=first_pos,
                end_pos=stop_pos,
                **budget_kwargs,
            )

        must_keep_np[seg_start:seg_end] = keep

    # Decode-only fast path: when every scheduled token is force-kept and no
    # request has a Top-K budget, compaction would be a no-op and is skipped.
    if not must_keep_np.all():
        return True, None
    return bool((topk_budget_np > 0).any()), None