compaction_step.py 2.11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import torch

import vllm.envs as envs
from vllm.v1.kv_compression.slot_mapping import topk_kv_compact_slot_mapping
from vllm.v1.kv_compression.snapkv_score import snapkv_like_token_scores


def snapkv_window_for_topk_budget(
    *,
    topk_budget: torch.Tensor,  # [B] int32
    window: int,
) -> torch.Tensor:
    """Return a per-request SnapKV observation-window tensor.

    Every request whose Top-K budget is strictly positive gets ``window``;
    every other request (budget <= 0) gets 0. The original design intent is
    that a zero window lets the Triton scoring kernel skip that request,
    since requests without a Top-K budget need no token scores.

    Args:
        topk_budget: Per-request Top-K budgets (annotated as ``[B] int32``).
        window: Window size assigned to requests with a positive budget;
            coerced with ``int()``.

    Returns:
        A new tensor with the same shape, dtype and device as
        ``topk_budget``.
    """
    inactive = topk_budget <= 0
    # Start with the full window everywhere, then zero out requests that
    # have no Top-K budget.
    windows = torch.full_like(topk_budget, int(window))
    return windows.masked_fill(inactive, 0)


def compute_compact_dst_slots_for_step(
    *,
    query: torch.Tensor,  # [T, Hq, D] for this step
    key: torch.Tensor,  # [T, Hkv, D] for this step
    query_start_loc: torch.Tensor,  # [B+1]
    seq_lens: torch.Tensor,  # [B] int32
    block_table: torch.Tensor,  # [B, max_blocks]
    block_size: int,
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    topk_budget_max: int,
    max_query_len: int,
    sm_scale: float,
) -> torch.Tensor:
    """Compute per-token KV compaction destinations for one step.

    When ``topk_budget_max`` is positive, SnapKV-like token scores are
    computed first via ``snapkv_like_token_scores``. The observation window
    comes from ``envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW`` and is zeroed for
    requests whose own budget is not positive.

    When ``topk_budget_max`` is not positive, scoring is skipped entirely
    and ``token_scores=None`` is forwarded.

    The result is whatever ``topk_kv_compact_slot_mapping`` returns for the
    given scores, keep mask, budgets and paging metadata.
    """
    budget_cap = int(topk_budget_max)

    scores = None
    if budget_cap > 0:
        # The env var is read only when scoring is actually needed.
        snapkv_window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
        per_request_window = snapkv_window_for_topk_budget(
            topk_budget=topk_budget,
            window=snapkv_window,
        )
        scores = snapkv_like_token_scores(
            query=query,
            key=key,
            query_start_loc=query_start_loc,
            window=per_request_window,
            sm_scale=float(sm_scale),
        )

    slot_mapping_kwargs = dict(
        token_scores=scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        block_table=block_table,
        block_size=int(block_size),
        max_query_len=int(max_query_len),
        topk_budget_max=budget_cap,
    )
    return topk_kv_compact_slot_mapping(**slot_mapping_kwargs)