common.py 12.2 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from abc import ABC, abstractmethod
import os
from typing import Optional

import torch

from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv


class BaseCompressionMethod(ABC):
    """
    Abstract interface for KV cache compression methods.

    A compression method is implemented as a pair of scoring phases (either
    of which may be a no-op) that run before and after rotary position
    embedding (RoPE) is applied:

      1. ``pre_rope_scoring`` operates on pre-RoPE Q/K.

      2. ``post_rope_scoring`` operates on post-RoPE Q/K and can either:
         - refine / reweight the pre-RoPE scores, or
         - compute new, potentially position-aware, scores.

    Concrete subclasses must implement both static methods. Each returns a
    single tensor of scores (or ``None`` if the phase is a no-op), which the
    caller can then feed into the shared
    "scores -> top-k indices -> KV extraction" pipeline.
    """

    @staticmethod
    @abstractmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute per-token importance scores from pre-RoPE queries/keys.

        Args:
            :param q:
                Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param context:
                vllm.kvprune.utils.context.Context object carrying additional
                metadata, such as batch mappings or temporary buffers.

        Returns:
            :return Optional[torch.Tensor]:
                A tensor of scores (e.g. per-token, per-head importance values)
                to be passed to ``post_rope_scoring`` or directly into the
                top-k selection step. If this phase is a no-op, implementations
                should return ``None``. Shape ``[total_tokens, HKV]``.
        """

    @staticmethod
    @abstractmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute or refine importance scores from post-RoPE queries/keys.

        This method is called after rotary embeddings have been applied. It can
        optionally use both the post-RoPE Q/K and any scores produced by
        ``pre_rope_scoring`` to produce the final scores used for token
        selection.

        Common patterns include:
          * Using ``pre_rope_scores`` as a base signal and applying a
            position-aware correction.
          * Only computing scores that depend on absolute or relative positions.
          * Simply passing through ``pre_rope_scores`` unchanged.

        Args:
            :param q:
                Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param pre_rope_scores:
                Optional scores returned by ``pre_rope_scoring``. ``None`` if
                the pre-RoPE phase returned ``None``.
            :param context:
                vllm.kvprune.utils.context.Context object carrying additional
                metadata, such as batch mappings or temporary buffers.

        Returns:
            :return Optional[torch.Tensor]:
                Final importance scores to be consumed by the compression
                pipeline (for top-k token selection). If this phase is a
                no-op, implementations may return ``pre_rope_scores``. If
                ``None`` is returned, no compression will be applied.
        """


class NoCompression(BaseCompressionMethod):
    """
    Trivial compression method that disables KV cache compression.

    The pre-RoPE phase produces no scores, and the post-RoPE phase passes the
    (absent) pre-RoPE scores through unchanged. The final result is therefore
    ``None``, which per the ``BaseCompressionMethod`` contract means no
    compression is applied.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Return ``None``: this method computes no pre-RoPE scores."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Return ``pre_rope_scores`` unchanged (``None`` in normal use)."""
        return pre_rope_scores


def extract_and_store_top_kv(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    K_TILE: int = 16,
    padding: float = -float("inf"),
) -> None:
    """
    Select top-scoring (token, head) pairs and store their K/V into the cache.

    This helper chains candidate selection (``scores_to_retain_indices``) and
    the cache write (``prefill_store_topk_kv``) in one call so both can be
    issued on a single stream.

    The padding strategy is read from the ``VLLM_KVPRUNE_PADDING_MODE``
    environment variable on every call (case-insensitive, surrounding
    whitespace ignored):

    - ``per_head`` (default): pad each head with its highest-scoring remaining
      tokens.
    - ``global_scan``: legacy global ranking order; the store kernel scans
      forward for padding entries.

    Args:
        :param scores:
            ``[N_total, H]`` scores in packed varlen format.
        :param cu_seqlens_k:
            ``[B + 1]`` cumulative key sequence lengths.
        :param max_k_len:
            Forwarded to ``scores_to_retain_indices``, which does not use it.
        :param top_k:
            Forwarded to ``scores_to_retain_indices``, which does not use it.
        :param H:
            Number of key heads.
        :param num_tokens_to_retain:
            ``[B]`` requested number of (token, head) pairs per batch element.
            Each entry is clamped to ``[0, seq_len * H]`` before use, and the
            clamped tensor is what the store kernel receives.
        :param padding:
            Forwarded to ``scores_to_retain_indices``, which ignores it; kept
            for backward compatibility.

        The remaining arguments are passed through to ``prefill_store_topk_kv``
        unchanged.

    Raises:
        ValueError: if ``num_tokens_to_retain`` is ``None``, or (raised by
            ``scores_to_retain_indices``) if the padding mode is unsupported.
    """
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if num_tokens_to_retain is None:
        raise ValueError("num_tokens_to_retain must be set")

    # per_head: per-head highest-scoring remaining tokens for page padding.
    # global_scan: legacy global ranking order, padded by scanning forward in-kernel.
    padding_mode = os.environ.get(
        "VLLM_KVPRUNE_PADDING_MODE", "per_head"
    ).strip().lower()

    # A batch element can never retain more (token, head) pairs than it has,
    # nor a negative number of them.
    seq_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    max_pairs_per_batch = seq_lens.to(
        device=num_tokens_to_retain.device, dtype=num_tokens_to_retain.dtype
    ) * H
    num_tokens_to_retain = torch.minimum(
        num_tokens_to_retain, max_pairs_per_batch
    ).clamp_min(0)

    indices_topk, candidate_counts = scores_to_retain_indices(
        scores,
        cu_seqlens_k=cu_seqlens_k,
        max_k_len=max_k_len,
        top_k=top_k,
        H=H,
        num_tokens_to_retain=num_tokens_to_retain,
        page_size=PAGE_SIZE,
        padding_mode=padding_mode,
        padding=padding,
    )
    prefill_store_topk_kv(
        new_keys=new_keys,
        new_vals=new_vals,
        indices_topk=indices_topk,
        candidate_counts=candidate_counts,
        num_tokens_to_retain=num_tokens_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        k_cache=k_cache,
        v_cache=v_cache,
        cu_seqlens_k=cu_seqlens_k,
        PAGE_SIZE=PAGE_SIZE,
        PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
        K_TILE=K_TILE,
    )


def _per_head_padded_row(
    seq_scores: torch.Tensor, seq_len: int, H: int, keep: int, page_size: int
) -> torch.Tensor:
    """
    Return one sequence's candidates: global top-``keep`` plus per-head padding.

    Indices are local to the sequence and flattened as ``token * H + head``.
    The prefix is the global top-``keep`` pairs in descending score order.
    It is followed, head by head, by each head's best *unselected* tokens,
    just enough to bring that head's count up to a multiple of ``page_size``
    (capped by how many unselected tokens the head has).
    """
    device = seq_scores.device
    prefix = torch.topk(
        seq_scores.reshape(-1), k=keep, dim=0, largest=True, sorted=True
    ).indices

    selected_flat = torch.zeros(seq_len * H, dtype=torch.bool, device=device)
    selected_flat[prefix] = True
    selected_mask = selected_flat.view(seq_len, H)

    head_counts = torch.bincount(prefix % H, minlength=H)
    need_per_head = (page_size - (head_counts % page_size)) % page_size
    need_per_head = torch.minimum(need_per_head, seq_len - head_counts)

    pieces = [prefix]
    # One host transfer for all heads instead of one ``.item()`` per head.
    for h, need in enumerate(need_per_head.tolist()):
        if need <= 0:
            continue
        # Pick strictly among this head's unselected tokens. Masking the
        # selected ones to -inf and running topk over the whole column is not
        # enough: if unselected scores are themselves -inf, the ties allow
        # topk to return an already-selected token, duplicating a candidate.
        free_tok = torch.nonzero(~selected_mask[:, h], as_tuple=False).squeeze(1)
        best = torch.topk(
            seq_scores[free_tok, h], k=need, dim=0, largest=True, sorted=True
        ).indices
        pieces.append(free_tok[best] * H + h)

    return torch.cat(pieces, dim=0) if len(pieces) > 1 else prefix


def scores_to_retain_indices(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    num_tokens_to_retain: torch.Tensor,
    page_size: int,
    padding_mode: str = "per_head",
    padding: float = -float("inf"),
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build candidate token-head indices for compression writes.

    For each batch element, this helper returns:

    1. a prefix of the true global top-k ``(token, head)`` pairs, and
    2. a suffix of additional padding candidates according to ``padding_mode``:
       - ``per_head``: choose each head's highest-scoring remaining
         (not already selected) tokens, so every candidate is distinct.
       - ``global_scan``: keep the legacy global ranking order (all pairs,
         descending) and let the store kernel scan forward until it finds
         enough entries for that head.

    The page-alignment requirement comes from the paged KV cache, but the
    padding candidates themselves do not need to be discovered inside the
    Triton store kernel. Choosing them here avoids the older "scan the global
    candidate list until you stumble across enough entries for this head"
    behavior, which could distort the retained set even though the page-table
    / reclaim logic only cares about the final per-head counts.

    Args:
        :param scores:
            Tensor of shape ``[N_total, HKV]`` containing scores for each
            (token, head) pair in packed varlen format.
        :param cu_seqlens_k:
            Tensor of shape ``[B + 1]`` (int32) with cumulative key sequence
            lengths for each batch element. The total number of tokens
            satisfies ``N_total = cu_seqlens_k[-1]``.
        :param max_k_len:
            Kept for API compatibility; not used.
        :param top_k:
            Kept for API compatibility with the caller; not used. The retained
            prefix is determined by ``num_tokens_to_retain``; the tail is
            built from per-head padding needs.
        :param H:
            Number of key heads; must match ``scores.shape[1]``.
        :param num_tokens_to_retain:
            The true number of token-head pairs to keep for each batch element
            before page padding. Values above ``seq_len * H`` are capped;
            values ``<= 0`` yield no candidates for that element.
        :param page_size:
            Page size of the KV cache. Determines how many extra candidates
            are needed per head to reach page alignment.
        :param padding_mode:
            ``per_head`` for per-head optimal padding candidates, or
            ``global_scan`` for the legacy "scan the global ranking" behavior.
        :param padding:
            Kept for backward compatibility; no longer used.

    Returns:
        A tuple ``(indices, counts)`` where:

        - ``indices`` is ``[B, MAX_SEL]`` int64, containing global flattened
          ``token * H + head`` indices (zero-filled past each row's count;
          ``[B, 1]`` of zeros when no row has candidates).
        - ``counts`` is ``[B]`` int32, the number of valid candidates for each
          batch row inside ``indices``.

    Raises:
        ValueError: if ``padding_mode`` is not a supported mode.
    """
    del max_k_len, top_k, padding

    # Validate before doing any work or allocation.
    if padding_mode not in ("per_head", "global_scan"):
        raise ValueError(
            "Unsupported VLLM_KVPRUNE_PADDING_MODE. "
            f"Expected 'per_head' or 'global_scan', got {padding_mode!r}."
        )

    device = scores.device
    B = cu_seqlens_k.numel() - 1
    # One host transfer each instead of a ``.item()`` sync per batch element.
    bounds = cu_seqlens_k.tolist()
    retain = num_tokens_to_retain.tolist()

    row_indices: list[torch.Tensor] = []
    counts: list[int] = []
    for b in range(B):
        s, e = int(bounds[b]), int(bounds[b + 1])
        seq_len = e - s
        total_pairs = seq_len * H
        keep = min(int(retain[b]), total_pairs)
        if total_pairs == 0 or keep <= 0:
            row_indices.append(torch.empty(0, dtype=torch.int64, device=device))
            counts.append(0)
            continue

        seq_scores = scores[s:e, :]  # [L, H]
        if padding_mode == "global_scan":
            row = torch.argsort(seq_scores.reshape(-1), dim=0, descending=True)
        else:
            row = _per_head_padded_row(seq_scores, seq_len, H, keep, page_size)

        # Shift local (token * H + head) indices into the packed global layout.
        row_indices.append(row + s * H)
        counts.append(int(row.numel()))

    candidate_counts = torch.tensor(counts, dtype=torch.int32, device=device)

    max_sel = max(counts, default=0)
    if max_sel == 0:
        return (
            torch.zeros((B, 1), dtype=torch.int64, device=device),
            candidate_counts,
        )

    indices = torch.zeros((B, max_sel), dtype=torch.int64, device=device)
    for b, row in enumerate(row_indices):
        if row.numel():
            indices[b, : row.numel()] = row
    return indices, candidate_counts