page_table.py 12.3 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import heapq
import logging
from enum import Enum, auto
from typing import List, Optional, Union

import torch
from vllm.kvprune.config.constants import RESERVED_BATCH
from vllm.kvprune.kv_cache.write_page_table import scatter_to_page_table

logger = logging.getLogger(__name__)


def cdiv(a, b):
    """
    Ceiling division: the smallest integer ``q`` with ``q * b >= a`` for
    non-negative ``a`` and positive ``b``.

    Uses only ``+``, ``-`` and ``//``, so it works on anything supporting
    those operators (this file calls it with both Python ints and integer
    tensors).
    """
    # Bias the numerator so that floor division rounds up instead of down.
    biased = a + b - 1
    return biased // b


def next_multiple(a, b):
    """
    Round ``a`` up to the nearest multiple of ``b``.

    A value that is already a multiple of ``b`` is returned unchanged.
    """
    groups = cdiv(a, b)  # number of b-sized groups needed to cover a
    return groups * b


class KVAllocationStatus(Enum):
    """Outcome of a KV-cache allocation request."""

    # The request would need more logical pages for some (layer, head) than
    # the per-head page-table limit allows.
    EXCEEDS_MAX_SEQUENCE_LENGTH = 1
    # At least one layer does not currently have enough free physical pages.
    EXCEEDS_CURRENTLY_AVAILABLE_PAGES = 2
    # Not produced anywhere in this file; presumably reported by callers when
    # no batch slot is available -- confirm against the call sites.
    EXCEEDS_MAX_NUM_BATCHES = 3
    # The requested capacity is available (possibly without any new pages).
    SUCCESS = 4


class PagedKVCache(torch.nn.Module):
    """
    Global paged KV cache.
    This module manages:
      * A global K/V backing buffer for all layers:
          ``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
        where the first dimension indexes K vs V.
      * A per-layer page table:
          ``page_table[num_layers, max_num_batches + 1, H_kv, max_pages_per_head]``,
        mapping logical (batch, kv-head, logical_page) to a physical page ID
        in the global K/V buffer. The extra batch row accounts for the
        reserved ``RESERVED_BATCH`` index.
      * Per-layer, per-(batch, kv-head) logical sequence lengths
        ``bh_seq_lens[num_layers, max_num_batches + 1, H_kv]`` (in tokens),
        and the number of allocated pages ``bh_num_pages`` (same shape) for
        each (layer, batch, head).
      * A page allocator implemented as a min-heap of free physical pages
        per layer, plus free batch indices.
    Pages are of fixed size ``page_size`` tokens.
    Args:
        :param num_layers:
            Number of transformer layers that will use this cache.
        :param max_logical_pages_per_head:
            Maximum number of logical pages that can be assigned to a single
            (batch, kv-head) pair.
        :param num_pages:
            Total number of physical pages available in the global cache per
            layer. The global K/V buffers are of length
            ``num_pages * page_size`` along the token dimension.
        :param  page_size:
            Number of tokens stored per page.
        :param H_kv:
            Number of KV heads per layer.
        :param  head_dim:
            Head dimension for K/V.
        :param max_num_batches:
            Maximum number of concurrent batches / sequences supported. One
            additional batch index (``RESERVED_BATCH``) is allocated on top
            of this count and reserved for internal use.
        :param dtype:
            Data type of K/V entries (e.g. ``torch.float16`` or ``torch.bfloat16``).
        :param device:
            Device on which to allocate the cache (string, torch.device, or
            int; defaults to ``"cuda"``).
    """

    def __init__(
        self,
        num_layers: int,
        max_logical_pages_per_head: int,
        num_pages: int,
        page_size: int,  # tokens per page
        H_kv: int,
        head_dim: int,
        max_num_batches: int,
        dtype: torch.dtype,
        device: Union[str, torch.device, int] = "cuda",
    ):
        """
        Allocate the K/V backing buffer, the page table, the per-(batch, head)
        metadata tensors, and the CPU-side allocator state.

        One batch slot is added on top of ``max_num_batches``; the slot
        ``RESERVED_BATCH`` is never handed out by :meth:`new_batch`.
        """
        super().__init__()

        # Requested batch count plus the one reserved slot.
        total_batch_slots = max_num_batches + 1

        self.num_layers = num_layers
        self.n_pages = num_pages
        self.page_size: int = int(page_size)
        self.H_kv = int(H_kv)
        self.head_dim = head_dim
        self.max_pages_per_head = max_logical_pages_per_head
        self.max_num_batches = total_batch_slots

        # Global K/V storage; dim 0 selects K (0) or V (1). Left uninitialised.
        self.kv_cache = torch.empty(
            (2, num_layers, num_pages * page_size, head_dim),
            dtype=dtype,
            device=device,
        )

        # Logical page -> physical page id. Left uninitialised; only the first
        # ``bh_num_pages`` entries of each row are meaningful.
        self.page_table = torch.empty(
            (num_layers, total_batch_slots, H_kv, self.max_pages_per_head),
            device=device,
            dtype=torch.int32,
        )

        per_head_shape = (num_layers, total_batch_slots, H_kv)
        # Logical sequence length (tokens) per (layer, batch, head).
        self.bh_seq_lens = torch.zeros(
            per_head_shape, device=device, dtype=torch.int32
        )
        # Number of allocated pages per (layer, batch, head).
        self.bh_num_pages = torch.zeros(
            per_head_shape, device=device, dtype=torch.int32
        )

        # One min-heap of free physical page ids per layer.
        self.free_pages: List[List[int]] = []
        for _ in range(num_layers):
            layer_heap = list(range(num_pages))
            heapq.heapify(layer_heap)
            self.free_pages.append(layer_heap)

        # Free batch slots in descending order so that ``pop()`` hands out the
        # lowest index first; the reserved slot is never offered.
        self.free_batches: List[int] = list(range(total_batch_slots - 1, -1, -1))
        self.free_batches.remove(RESERVED_BATCH)

        # Physical page ids owned by each batch, indexed [batch][layer]; used
        # to return pages to the heaps when a batch is freed.
        self.pages_indices_per_batch: List[List[set[int]]] = [
            [set() for _layer in range(num_layers)]
            for _slot in range(total_batch_slots)
        ]

    def new_batch(self) -> Optional[int]:
        """
        Reserve a new batch slot.
        A batch slot corresponds to a row in ``bh_seq_lens`` /
        ``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
        heads. This method checks whether a free batch index is available, and
        whether each layer has at least ``H_kv`` free pages remaining.
        If both checks pass, it returns a batch index and removes it from
        ``free_batches``. Otherwise, it returns ``None``.

        Returns:
            :return Optional[int]:
                Newly reserved batch index, or ``None`` if no capacity is
                available.
        """
        if self.free_batches and all([self.H_kv <= len(fp) for fp in self.free_pages]):
            return self.free_batches.pop()
        return None

    def reserve_tokens(self, batch_index: int, add_tokens: int) -> KVAllocationStatus:
        """
        Ensure enough pages are allocated to handle ``add_tokens`` new tokens.

        The reservation is all-or-nothing: page availability is checked for
        every layer before any page is taken from the free heaps, so a
        non-``SUCCESS`` result leaves the allocator state untouched.
        Args:
            :param batch_index:
                Batch index to reserve space for.
            :param add_tokens:
                Number of additional tokens to reserve capacity for.
                All heads in this batch and all layers reserve
                the same number of extra tokens.
        Returns:
            :return KVAllocationStatus:
                ``SUCCESS`` if capacity is (now) available;
                ``EXCEEDS_MAX_SEQUENCE_LENGTH`` if any (layer, head) would
                need more than ``max_pages_per_head`` pages;
                ``EXCEEDS_CURRENTLY_AVAILABLE_PAGES`` if any layer lacks
                enough free physical pages.
        """
        cur_bh_lens = self.bh_seq_lens[:, batch_index]  # [L, H]
        curr_pages = self.bh_num_pages[:, batch_index]  # [L, H]
        curr_cap_tokens = curr_pages * self.page_size  # [L, H]
        need_tokens = cur_bh_lens + add_tokens  # [L, H]
        if (need_tokens <= curr_cap_tokens).all():
            return KVAllocationStatus.SUCCESS

        # Heads that already have spare capacity must request zero pages.
        # Without the clamp, a head with a page or more of slack yields a
        # negative page count, which would shrink its ``bh_num_pages`` and
        # reduce the per-layer total below what the other heads need.
        missing_tokens = (need_tokens - curr_cap_tokens).clamp(min=0)
        add_pages = cdiv(missing_tokens, self.page_size)  # [L, H], >= 0
        new_total_pages = curr_pages + add_pages
        if (new_total_pages > self.max_pages_per_head).any():
            return KVAllocationStatus.EXCEEDS_MAX_SEQUENCE_LENGTH

        # CPU work: verify every layer first so that a failure never leaves
        # pages half-allocated.
        pages_per_layer_cpu = add_pages.sum(dim=-1).tolist()
        for layer_index in range(self.num_layers):
            if pages_per_layer_cpu[layer_index] > len(self.free_pages[layer_index]):
                return KVAllocationStatus.EXCEEDS_CURRENTLY_AVAILABLE_PAGES

        # Pop the lowest free page ids per layer, in layer order, and record
        # ownership so the pages can be returned when the batch is freed.
        new_phys_pages = []
        for layer_index in range(self.num_layers):
            this_layer_pages = [
                heapq.heappop(self.free_pages[layer_index])
                for _ in range(pages_per_layer_cpu[layer_index])
            ]
            self.pages_indices_per_batch[batch_index][layer_index].update(
                this_layer_pages
            )
            new_phys_pages.extend(this_layer_pages)

        # Create the id tensor on the cache's own device rather than a
        # hard-coded "cuda", so a cache built on another device still works.
        new_phys_pages = torch.tensor(
            new_phys_pages, dtype=torch.int32, device=self.page_table.device
        )

        # Presumably writes the new physical ids into each head's page-table
        # row after its existing ``curr_pages`` entries -- confirm against
        # ``write_page_table.scatter_to_page_table``.
        scatter_to_page_table(
            add_pages=add_pages,
            new_phys_pages=new_phys_pages,
            curr_pages=curr_pages,
            page_table=self.page_table[:, batch_index],
            max_pages_per_head=self.max_pages_per_head,
        )

        # Updated only after the scatter, because ``curr_pages`` is a view
        # into ``bh_num_pages``.
        self.bh_num_pages[:, batch_index, :] = new_total_pages.to(
            self.bh_num_pages.dtype
        )
        return KVAllocationStatus.SUCCESS

    def reclaim_pages(
        self,
        batch_index: int,
        future_reserve_tokens: int = 0,
    ):
        """
        Return surplus pages of one batch to the per-layer free heaps.

        For every (layer, head) of ``batch_index`` the allocation is shrunk to
        the number of pages needed for the current sequence length plus
        ``future_reserve_tokens``. Only ``bh_num_pages`` and the CPU-side
        allocator state are updated; ``page_table`` entries are left as-is.

        Args:
            :param batch_index:
                Batch index whose pages should be compacted.
            :param future_reserve_tokens:
                Optional number of extra tokens to keep capacity for, beyond
                the current sequence length, to reduce allocation churn for
                sequences that are about to grow.

        Returns:
            :return int:
                Approximate number of bytes freed across both K and V
                (``0`` when nothing could be reclaimed).
        """
        _, num_batches, num_heads = self.bh_seq_lens.shape
        assert 0 <= batch_index < num_batches
        slots_per_head = self.max_pages_per_head

        tokens_to_keep = (
            self.bh_seq_lens[:, batch_index, :] + future_reserve_tokens
        )  # [L, H]
        pages_allocated = self.bh_num_pages[:, batch_index, :]  # [L, H]
        # Flattened copy/view of the [L, H, P] page-table slice for this batch.
        flat_page_table = self.page_table[:, batch_index, :, :].reshape(-1)

        # Pages still needed: ceil(tokens / page_size), never more than what
        # is currently allocated.
        pages_to_keep = torch.minimum(
            cdiv(tokens_to_keep, self.page_size), pages_allocated
        )

        # Logical slot indices 0..P-1, broadcast against [L, H, 1].
        slot = torch.arange(
            slots_per_head, device=self.bh_seq_lens.device, dtype=torch.int32
        ).view(1, 1, slots_per_head)
        is_allocated = slot < pages_allocated.unsqueeze(-1)  # [L, H, P]
        is_surplus = slot >= pages_to_keep.unsqueeze(-1)  # [L, H, P]
        # Slots to release: allocated and in [pages_to_keep, pages_allocated).
        release_mask = (is_allocated & is_surplus).view(-1)  # [L*H*P]
        if not release_mask.any():
            return 0

        release_idx = release_mask.nonzero(as_tuple=False).squeeze(-1)

        # Physical ids of the released slots, and the layer each belongs to.
        # Layout is [L, H, P] -> flat index = ((l * H) + h) * P + p, so the
        # layer is recovered by dividing by H * P.
        released_pages = flat_page_table[release_idx].tolist()
        released_layers = (
            (release_idx // (num_heads * slots_per_head)).to(torch.int32).tolist()
        )

        self.bh_num_pages[:, batch_index, :] = pages_to_keep

        owned_by_layer = self.pages_indices_per_batch[batch_index]
        for page_id, layer_id in zip(released_pages, released_layers):
            owned_by_layer[layer_id].remove(page_id)
            heapq.heappush(self.free_pages[layer_id], page_id)

        bytes_per_page = (
            self.page_size * self.head_dim * self.kv_cache.element_size()
        )
        # Times two: each page exists once in K and once in V.
        return len(released_pages) * bytes_per_page * 2

    def _free_batch_layer(self, layer_index: int, batch_index: int) -> None:
        """
        Free all pages belonging to batch_index and reset its metadata.
        """
        # Return pages to the global heap
        for phys in self.pages_indices_per_batch[batch_index][layer_index]:
            heapq.heappush(self.free_pages[layer_index], int(phys))

        self.pages_indices_per_batch[batch_index][layer_index] = set()

    def free_batch(self, batch_index: int) -> None:
        """
        Free all resources associated with a batch index.
        Args:
            :param batch_index:
                Batch index to release. Must have been previously allocated
                via :meth:`new_batch`.
        """
        for layer in range(self.num_layers):
            self._free_batch_layer(layer, batch_index)
        self.bh_seq_lens[:, batch_index].zero_()
        self.bh_num_pages[:, batch_index].zero_()
        self.free_batches.append(batch_index)

    def layer_slices(self, layer: int):
        """
        Return layer-local views needed by the attention module.

        For a given ``layer`` index, this method returns the slices of the
        global K/V cache, page table, and per-(batch, head) sequence lengths
        corresponding to that layer.
        Args:
            :param layer:
                Layer index ``l`` in ``[0, num_layers)``.

        Returns:
            :return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                ``(k, v, pt, bh)`` as described above.
        """
        assert 0 <= layer < self.num_layers
        k = self.kv_cache[0, layer]
        v = self.kv_cache[1, layer]
        pt = self.page_table[layer]
        bh = self.bh_seq_lens[layer]
        return k, v, pt, bh