scheduler_accounting.py 4.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from typing import Any

import vllm.envs as envs
from vllm.v1.kv_compression.budget import (compute_prompt_keep_len,
                                           compute_topk_budget_step,
                                           count_prompt_must_keep_in_range)


def maybe_init_num_kv_tokens_on_running_transition(
    *,
    request: Any,
    num_computed_tokens: int,
    kv_compression_enabled: bool,
) -> None:
    """Seed `request.num_kv_tokens` at the moment a request turns RUNNING.

    Without KV compression the physical KV length is simply the logical
    length, so the counter is unconditionally set to `num_computed_tokens`.

    With KV compression the counter is only seeded when it is still zero (or
    the attribute is absent) *and* the request already carries computed
    tokens. That situation arises on connector/cache-hit paths, where a
    request enters RUNNING with precomputed tokens; leaving the counter at
    zero there would yield incorrect KV write offsets. In every other
    compression-enabled case the existing value is left untouched.

    Args:
        request: Object whose `num_kv_tokens` attribute is read and written.
        num_computed_tokens: Tokens already computed for the request.
        kv_compression_enabled: Whether KV compression is active.
    """
    if kv_compression_enabled:
        # A missing attribute is treated the same as a zero counter.
        counter_is_unset = getattr(request, "num_kv_tokens", 0) == 0
        if not (counter_is_unset and num_computed_tokens > 0):
            # Already tracked, or nothing precomputed: leave as-is.
            return

    request.num_kv_tokens = num_computed_tokens


def update_num_kv_tokens_after_schedule(
    *,
    request: Any,
    start_pos: int,
    num_scheduled_token: int,
    chunked_prefill_enabled: bool,
    kv_compression_enabled: bool,
) -> None:
    """Advance `request.num_kv_tokens` after scheduling a step.

    The step covers logical positions
    `[start_pos, start_pos + num_scheduled_token)`. The prompt length is read
    from `request.num_prompt_tokens` (treated as 0 when the attribute is
    missing). Behaviour by case:

    * `num_scheduled_token <= 0`: no-op.
    * KV compression disabled: the counter grows by `num_scheduled_token`.
    * KV compression enabled with chunked prefill: steps that lie entirely in
      decode, or that end before the prompt end, keep every token. The step
      that reaches or crosses the prompt end *overwrites* the counter with
      `compute_prompt_keep_len(...)` plus the number of scheduled positions
      at or after the prompt end.
    * KV compression enabled without chunked prefill: the counter grows by
      the scheduled decode positions plus the per-step prompt amounts
      returned by `count_prompt_must_keep_in_range` and
      `compute_topk_budget_step`.

    Args:
        request: Object exposing `num_kv_tokens` (read/written) and,
            optionally, `num_prompt_tokens`.
        start_pos: Logical position of the first token scheduled this step.
        num_scheduled_token: Number of tokens scheduled this step.
        chunked_prefill_enabled: Whether chunked prefill is active.
        kv_compression_enabled: Whether KV compression is active.
    """
    if num_scheduled_token <= 0:
        return

    if not kv_compression_enabled:
        request.num_kv_tokens += num_scheduled_token
        return

    end_pos = start_pos + int(num_scheduled_token)
    prompt_end = int(getattr(request, "num_prompt_tokens", 0))

    # Chunked prefill: do not change the prompt KV length mid-prefill.
    # Otherwise, the next prefill chunk would attend to a truncated history
    # (semantic change / quality collapse). Instead, keep the full prompt KV
    # until the prompt is fully ingested, then apply a one-shot prompt
    # compaction before decode. Pure decode steps likewise keep every token.
    if chunked_prefill_enabled and (start_pos >= prompt_end
                                    or end_pos < prompt_end):
        request.num_kv_tokens += num_scheduled_token
        return

    # The compression settings are only needed from here on, so they are
    # looked up after the keep-everything fast paths above.
    # NOTE(review): `envs` attributes are presumably re-evaluated on each
    # access, which is why the lookups are kept off the per-decode-step
    # path - confirm against vllm.envs.
    prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
    prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
    protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
    protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
    keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN

    # Scheduled positions at or beyond the prompt end are decode tokens and
    # are always kept.
    kept_decode = max(0, end_pos - max(start_pos, prompt_end))

    if chunked_prefill_enabled:
        # This step finishes the prompt (and may include decode tokens in rare
        # cases). Apply the final prompt compression length; this replaces the
        # running counter rather than adding to it.
        kept_prompt_total = compute_prompt_keep_len(
            prompt_len=prompt_end,
            protected_prefix=protected_prefix,
            protected_suffix=protected_suffix,
            keep_last_token=keep_last,
            prompt_ratio=prompt_ratio,
            prompt_budget=prompt_budget,
        )
        request.num_kv_tokens = kept_prompt_total + kept_decode
        return

    # No chunked prefill: account for this step incrementally. The prompt
    # portion of the scheduled range contributes its must-keep tokens plus
    # its share of the top-k budget; the decode portion is kept in full.
    kept_prompt_must_keep = count_prompt_must_keep_in_range(
        prompt_len=prompt_end,
        start_pos=start_pos,
        end_pos=end_pos,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last,
    )
    kept_prompt_topk = compute_topk_budget_step(
        prompt_len=prompt_end,
        start_pos=start_pos,
        end_pos=end_pos,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last,
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
    )

    request.num_kv_tokens += (kept_decode + kept_prompt_must_keep +
                              kept_prompt_topk)