slot_mapping.py 4.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from typing import Optional

import torch

from vllm.v1.kv_compression.topk_select import (_packed_varlen_coords,
                                                _topk_keep_mask_and_local_rank)


def _dst_slots_from_keep_mask_and_local_rank(
    *,
    keep_mask: torch.Tensor,  # [T] bool
    local_rank: torch.Tensor,  # [T] int64
    seq_lens: torch.Tensor,  # [B] int32
    lengths: torch.Tensor,  # [B] int64
    req_ids: torch.Tensor,  # [T] int64
    block_table: torch.Tensor,  # [B, max_blocks] int32
    block_size: int,
) -> torch.Tensor:
    """Convert keep_mask/local_rank into a per-token KV destination slot.

    A kept token ``i`` of request ``r = req_ids[i]`` is placed at KV position
    ``pos = base_kv[r] + local_rank[i]``, where
    ``base_kv[r] = max(seq_lens[r] - lengths[r], 0)``. That position is then
    mapped to the physical slot
    ``block_table[r, pos // block_size] * block_size + pos % block_size``.

    Args:
        keep_mask: True for tokens that survive compaction.
        local_rank: Target offset of each token relative to ``base_kv`` of its
            request. Values for dropped tokens do not affect the result.
        seq_lens: Per-request sequence length including this step's
            scheduled segment.
        lengths: Per-request number of tokens scheduled in this step.
        req_ids: Request index of every token.
        block_table: Physical block numbers of every request.
        block_size: Number of slots per KV cache block.

    Returns:
        An int64 tensor of shape [T] holding the destination slot of each kept
        token and -1 for dropped tokens. Every entry is -1 when there are no
        tokens, no requests, or the block table has no columns.
    """
    device = keep_mask.device
    T = int(keep_mask.numel())
    dst_slots = torch.full((T, ), -1, device=device, dtype=torch.int64)
    if T == 0:
        return dst_slots

    B = int(seq_lens.numel())
    if B == 0:
        return dst_slots

    # A zero-width block table has no slot to place anything in. Without this
    # guard the clamp below would produce index -1 (clamp with min > max
    # yields max) and the gather would raise an IndexError.
    max_blocks = int(block_table.shape[1])
    if max_blocks == 0:
        return dst_slots

    # Base KV cache position for this step (i.e., KV length before writing this
    # scheduled segment). With KV compression enabled, seq_lens is derived from
    # num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
    base_kv = (seq_lens.to(torch.long) - lengths.to(torch.long)).clamp_min(0)
    base_kv_per_token = base_kv.index_select(0, req_ids)  # [T]
    dest_pos = base_kv_per_token + local_rank  # [T]
    dest_block_idx = dest_pos // block_size
    # Offset of the destination position within its block.
    dest_off = dest_pos - dest_block_idx * block_size

    # Dropped tokens may point past the block table; clamp so the gather stays
    # in bounds. Their slots are discarded by keep_mask below.
    dest_block_idx_safe = dest_block_idx.clamp(0, max_blocks - 1).to(torch.long)
    block_nums = block_table[req_ids, dest_block_idx_safe].to(torch.long)
    dest_slot = block_nums * block_size + dest_off
    return torch.where(keep_mask, dest_slot.to(torch.int64), dst_slots)


def topk_kv_compact_slot_mapping(
    *,
    token_scores: Optional[torch.Tensor],  # [T] float32
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    query_start_loc: torch.Tensor,  # [B+1]
    seq_lens: torch.Tensor,  # [B] int32
    block_table: torch.Tensor,  # [B, max_blocks]
    block_size: int,
    max_query_len: Optional[int] = None,
    topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
    """Compute where each scheduled token's KV entry lands after compaction.

    Token selection is delegated to ``_topk_keep_mask_and_local_rank``, which
    receives the scores, the ``must_keep`` flags and the per-request budgets.
    Its keep mask and local ranks are then turned into physical KV slots.

    Returns:
        An int64 tensor ``dst_slots`` of shape [T] where:
        - ``dst_slots[i] >= 0`` means token i is kept and rewritten to that
          KV cache slot;
        - ``dst_slots[i] == -1`` means token i is dropped after the step.
        All entries are -1 when there are no tokens, no requests, or no
        non-empty query segments.
    """
    num_tokens = int(must_keep.numel())
    num_reqs = int(topk_budget.numel())
    all_dropped = torch.full((num_tokens, ),
                             -1,
                             device=must_keep.device,
                             dtype=torch.int64)
    if num_tokens == 0 or num_reqs == 0:
        return all_dropped

    coords = _packed_varlen_coords(
        cu_seqlens=query_start_loc,
        total_tokens=num_tokens,
    )
    starts, _, lengths, req_ids, pos_in_req = coords
    if lengths.numel() == 0:
        return all_dropped

    # A host-side max query length (e.g. under piecewise graph execution)
    # avoids the device->host sync that lengths.max().item() would force.
    if max_query_len is None:
        longest = int(lengths.max().item())
    else:
        longest = int(max_query_len)
    if longest <= 0:
        return all_dropped

    keep_mask, local_rank, _ = _topk_keep_mask_and_local_rank(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        starts=starts,
        lengths=lengths,
        req_ids=req_ids,
        pos_in_req=pos_in_req,
        max_len=longest,
        topk_budget_max=topk_budget_max,
    )

    # seq_lens may carry padding past the real requests; trim to num_reqs.
    return _dst_slots_from_keep_mask_and_local_rank(
        keep_mask=keep_mask,
        local_rank=local_rank,
        seq_lens=seq_lens[:num_reqs],
        lengths=lengths,
        req_ids=req_ids,
        block_table=block_table,
        block_size=int(block_size),
    )


def kv_compaction_dst_rewrite_mapping(
    *,
    dst_slots: torch.Tensor,  # [T] int64
    src_slots: torch.Tensor,  # [T] int64
) -> torch.Tensor:
    """Keep only the destination slots of tokens that actually move.

    A token needs a rewrite when it is kept (``dst_slots >= 0``) and its
    destination differs from its current slot. Every other entry is set to
    -1, which the cache kernels treat as padding and skip.
    """
    is_kept = dst_slots.ge(0)
    has_moved = dst_slots.ne(src_slots)
    needs_rewrite = is_kept.logical_and(has_moved)
    return torch.where(needs_rewrite, dst_slots, -1)