runner_prompt_compaction.py 8.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from typing import Any

import torch

import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    CrossAttentionSpec,
    UniformTypeKVCacheSpecs,
)
from vllm.v1.kv_compression.forward_context import get_kv_compression_prompt_payload


def stash_kv_compression_prompt_payload_to_requests(*, runner: Any) -> None:
    """Persist prompt-end compaction indices from the forward context.

    This is the runner-side half of chunked-prefill scheme 3:
      flash_attn -> forward_context payload -> request state stash ->
      (next step) one-shot KV compaction.

    For each payload row ``i`` whose ``req_indices[i]`` is a valid position in
    ``runner.input_batch.req_ids`` naming a request present in
    ``runner.requests``, this sets on that request state:

    - ``kv_compression_prompt_idx_sorted``: a private copy of ``idx_sorted[i]``
    - ``kv_compression_prompt_keep_len``: ``int(keep_len[i])``
    - ``kv_compression_prompt_prompt_len``: ``int(prompt_lens[i])``

    It is a no-op when KV compression is disabled, chunked prefill is off,
    the forward context carries no payload (or any of the four payload
    fields is missing), or the runner lacks ``input_batch`` /
    ``input_batch.req_ids`` / ``requests``.

    Args:
        runner: The model runner; read for ``scheduler_config``,
            ``input_batch.req_ids`` and ``requests`` (a mapping from request
            id to request state supporting ``.get``).
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return
    scheduler_config = getattr(runner, "scheduler_config", None)
    if scheduler_config is None or not getattr(scheduler_config,
                                              "enable_chunked_prefill",
                                              False):
        return

    forward_context = get_forward_context()
    payload = get_kv_compression_prompt_payload(forward_context)
    if payload is None:
        return

    req_indices = payload.get("req_indices")
    idx_sorted = payload.get("idx_sorted")
    keep_len = payload.get("keep_len")
    prompt_lens = payload.get("prompt_lens")
    if (req_indices is None or idx_sorted is None or keep_len is None
            or prompt_lens is None):
        return

    input_batch = getattr(runner, "input_batch", None)
    if input_batch is None:
        return
    req_ids = getattr(input_batch, "req_ids", None)
    if req_ids is None:
        return
    requests = getattr(runner, "requests", None)
    if requests is None:
        return

    # Pull the per-row scalars to host once. If these tensors live on the
    # GPU, each ``.to("cpu")`` is a blocking device-to-host copy.
    # NOTE(review): assumes keep_len / prompt_lens / idx_sorted have one row
    # per req_indices entry — confirm against the payload producer.
    req_indices_cpu = req_indices.to(device="cpu", dtype=torch.int64).tolist()
    keep_cpu = keep_len.to(device="cpu", dtype=torch.int64).tolist()
    prompt_cpu = prompt_lens.to(device="cpu", dtype=torch.int64).tolist()
    for i, b in enumerate(req_indices_cpu):
        # ``b`` is a position in ``input_batch.req_ids``; rows pointing
        # outside it (e.g. negative values) or at an empty slot are ignored.
        if b < 0 or b >= len(req_ids):
            continue
        req_id = req_ids[b]
        if req_id is None:
            continue
        rs = requests.get(req_id)
        if rs is None:
            continue
        # Clone so the stash owns its storage. A plain slice is a view that
        # keeps the whole batched ``idx_sorted`` tensor alive until the
        # compaction step runs (possibly several steps later), and it would
        # be silently clobbered if the producer reuses that buffer.
        rs.kv_compression_prompt_idx_sorted = idx_sorted[i].clone()
        rs.kv_compression_prompt_keep_len = int(keep_cpu[i])
        rs.kv_compression_prompt_prompt_len = int(prompt_cpu[i])


def _clear_kv_compression_prompt_state(rs: Any) -> None:
    """Reset the three stashed prompt-compaction fields on a request state."""
    rs.kv_compression_prompt_idx_sorted = None
    rs.kv_compression_prompt_keep_len = None
    rs.kv_compression_prompt_prompt_len = None


def _collect_pending_prompt_compactions(
    input_batch: Any,
    requests: Any,
) -> list[tuple[Any, torch.Tensor, int]]:
    """Return ``(request_state, idx_sorted, keep_len)`` for ready requests.

    A request in ``input_batch.req_ids`` is ready when it has a stashed
    ``kv_compression_prompt_idx_sorted``, a non-None ``keep_len``, and
    ``num_computed_tokens >= num_prompt_tokens``. Ready requests whose
    ``keep_len`` is ``<= 0`` have their stash cleared and are not returned.
    Requests with an index stash but ``keep_len is None`` are left untouched.
    """
    pending: list[tuple[Any, torch.Tensor, int]] = []
    for req_id in input_batch.req_ids:
        if req_id is None:
            continue
        rs = requests.get(req_id)
        if rs is None:
            continue
        idx = rs.kv_compression_prompt_idx_sorted
        if idx is None:
            continue
        # Only apply once the prompt is fully ingested (decode stage).
        if rs.num_computed_tokens < rs.num_prompt_tokens:
            continue
        keep = rs.kv_compression_prompt_keep_len
        if keep is None:
            continue
        keep_i = int(keep)
        if keep_i <= 0:
            # No prompt tokens kept; clear and skip.
            _clear_kv_compression_prompt_state(rs)
            continue
        pending.append((rs, idx, keep_i))
    return pending


def _build_group_block_table(
    pending: list[tuple[Any, torch.Tensor, int]],
    group_id: int,
    device: Any,
) -> torch.Tensor | None:
    """Build the zero-padded ``(len(pending), max_blocks)`` int32 block table.

    Row ``i`` holds ``block_ids[group_id]`` of the i-th pending request
    (empty if the request has no entry for this group), padded with 0 on the
    right. Returns ``None`` when no pending request has a block in the group.
    """
    rows = [
        list(rs.block_ids[group_id]) if group_id < len(rs.block_ids) else []
        for rs, _, _ in pending
    ]
    max_blocks = max((len(row) for row in rows), default=0)
    if max_blocks == 0:
        return None
    padded = [row + [0] * (max_blocks - len(row)) for row in rows]
    block_table_cpu = torch.tensor(padded, dtype=torch.int32, device="cpu")
    return block_table_cpu.to(device=device, non_blocking=True)


def _self_attention_kv_caches(
    kv_cache_group_spec: Any,
    static_forward_context: Any,
) -> list[tuple[torch.Tensor, torch.Tensor]]:
    """Return deduplicated ``(key_cache, value_cache)`` pairs for one group.

    Layers whose spec is not an ``AttentionSpec`` (e.g. Mamba) or is a
    ``CrossAttentionSpec`` are skipped, as are layers missing from
    ``static_forward_context`` or without a non-empty ``kv_cache`` list.
    Caches that share storage (same data pointer) are returned only once so
    they are not compacted twice.
    """
    caches: list[tuple[torch.Tensor, torch.Tensor]] = []
    seen_cache_ptrs: set[int] = set()
    for layer_name in kv_cache_group_spec.layer_names:
        # Skip non-self-attention caches (e.g., encoder/decoder cross-attn)
        # and non-attention cache specs (e.g., Mamba).
        kv_cache_spec = kv_cache_group_spec.kv_cache_spec
        if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
            kv_cache_spec = kv_cache_spec.kv_cache_specs.get(layer_name)
        if kv_cache_spec is None or not isinstance(kv_cache_spec, AttentionSpec):
            continue
        if isinstance(kv_cache_spec, CrossAttentionSpec):
            continue

        layer = static_forward_context.get(layer_name)
        if layer is None:
            continue
        kv_cache_list = getattr(layer, "kv_cache", None)
        if not isinstance(kv_cache_list, list) or not kv_cache_list:
            continue
        # Only the first entry is compacted (presumably virtual engine 0 —
        # TODO confirm).
        kv_cache = kv_cache_list[0]
        if not current_platform.is_rocm():
            # Single tensor whose dim 0 separates K and V.
            if not isinstance(kv_cache, torch.Tensor):
                continue
            cache_ptr = int(kv_cache.data_ptr())
            if cache_ptr in seen_cache_ptrs:
                continue
            seen_cache_ptrs.add(cache_ptr)
            key_cache, value_cache = kv_cache.unbind(0)
        else:
            # ROCm path expects an explicit (key_cache, value_cache) pair.
            if (not isinstance(kv_cache, (tuple, list))
                    or len(kv_cache) != 2):
                continue
            key_cache, value_cache = kv_cache
            cache_ptr = int(key_cache.data_ptr())
            if cache_ptr in seen_cache_ptrs:
                continue
            seen_cache_ptrs.add(cache_ptr)
        caches.append((key_cache, value_cache))
    return caches


def maybe_apply_kv_compression_prompt_compaction(*, runner: Any) -> None:
    """Apply one-shot prompt KV compaction before the first decode step.

    Selects requests in ``runner.input_batch`` that carry a stashed prompt
    compaction and whose prompt is fully computed. Their stashed index rows
    are packed into a zero-padded ``(B, max_keep)`` int32 tensor plus an int32
    ``keep_len`` vector. Then, for every KV-cache group and every
    self-attention layer in it, ``front_compact_inplace_fa_triton`` is called
    on the layer's K/V cache with a per-group block table built from each
    request's ``block_ids``. (The exact in-place rewrite is defined by that
    kernel in ``kv_cache_triton``.)

    Afterwards the stash is cleared on every selected request, so the
    compaction is attempted at most once. It is cleared even if no layer
    was compacted, e.g. when ``compilation_config.static_forward_context``
    is unavailable.

    No-op when KV compression is disabled, the platform is not CUDA-like,
    chunked prefill is off, or nothing is ready. If
    ``runner.kv_cache_config`` is missing, this returns without clearing
    the stash.

    Args:
        runner: The model runner; read for ``scheduler_config``,
            ``input_batch``, ``requests``, ``device``, ``kv_cache_config``
            and ``compilation_config.static_forward_context``.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return
    if not current_platform.is_cuda_alike():
        return
    scheduler_config = getattr(runner, "scheduler_config", None)
    if scheduler_config is None or not getattr(scheduler_config,
                                              "enable_chunked_prefill",
                                              False):
        return

    input_batch = getattr(runner, "input_batch", None)
    if input_batch is None:
        return
    requests = getattr(runner, "requests", None)
    if requests is None:
        return

    pending = _collect_pending_prompt_compactions(input_batch, requests)
    if not pending:
        return

    # Checked before any device allocation or kernel import: without a KV
    # cache config there is nothing to compact and the stash is kept.
    kv_cache_config = getattr(runner, "kv_cache_config", None)
    if kv_cache_config is None:
        return

    device = runner.device
    keep_list = [keep for _, _, keep in pending]
    idx_batch = torch.zeros((len(pending), max(keep_list)),
                            device=device,
                            dtype=torch.int32)
    for i, (_, row, keep) in enumerate(pending):
        idx_batch[i, :keep] = row[:keep].to(device=device, dtype=torch.int32)
    keep_tensor = torch.tensor(keep_list, device=device, dtype=torch.int32)

    # Deferred import: the Triton kernels are only loaded when needed.
    from vllm.v1.kv_compression.kv_cache_triton import (
        front_compact_inplace_fa_triton, make_fa_cache_view)

    # Loop-invariant: look it up once rather than per KV-cache group.
    static_forward_context = getattr(
        getattr(runner, "compilation_config", None),
        "static_forward_context",
        None,
    )

    # Apply compaction to every self-attention layer's KV cache in-place.
    if static_forward_context is not None:
        for group_id, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            block_table = _build_group_block_table(pending, group_id, device)
            if block_table is None:
                continue
            for key_cache, value_cache in _self_attention_kv_caches(
                    kv_cache_group_spec, static_forward_context):
                k_view, v_view = make_fa_cache_view(key_cache=key_cache,
                                                    value_cache=value_cache)
                front_compact_inplace_fa_triton(
                    k_view,
                    v_view,
                    block_table,
                    idx_batch,
                    keep_tensor,
                )

    # One-shot: clear the stash so it is not retried on later steps.
    for rs, _, _ in pending:
        _clear_kv_compression_prompt_state(rs)