# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

7
from vllm.distributed import get_dcp_group
8
from vllm.logger import init_logger
9
from vllm.utils.math_utils import cdiv
10
from vllm.v1.utils import CpuGpuBuffer
11
12
13
14
15
16
17

# Module-level logger, named after this module.
logger = init_logger(__name__)


class BlockTable:
    def __init__(
        self,
18
        block_size: int,
19
20
        max_num_reqs: int,
        max_num_blocks_per_req: int,
21
        max_num_batched_tokens: int,
22
23
        pin_memory: bool,
        device: torch.device,
24
        kernel_block_size: int,
25
        dcp_kv_cache_interleave_size: int,
26
    ):
27
28
29
30
31
32
33
34
35
36
37
38
        """
        Args:
            block_size: Block size used for KV cache memory allocation
            max_num_reqs: Maximum number of concurrent requests supported.
            max_num_blocks_per_req: Maximum number of blocks per request.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block table.
            kernel_block_size: The block_size of underlying attention kernel.
                Will be the same as `block_size` if `block_size` is supported
                by the attention kernel.
        """
39
        self.max_num_reqs = max_num_reqs
40
        self.max_num_batched_tokens = max_num_batched_tokens
41
42
43
        self.pin_memory = pin_memory
        self.device = device

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
        if kernel_block_size == block_size:
            # Standard case: allocation and computation use same block size
            # No block splitting needed, direct mapping
            self.block_size = block_size
            self.blocks_per_kv_block = 1
            self.use_hybrid_blocks = False
        else:
            # Hybrid case: allocation block size differs from kernel block size
            # Memory blocks are subdivided to match kernel requirements
            # Example: 32-token memory blocks with 16-token kernel blocks
            # → Each memory block corresponds to 2 kernel blocks
            if block_size % kernel_block_size != 0:
                raise ValueError(
                    f"kernel_block_size {kernel_block_size} must divide "
                    f"kv_manager_block_size size {block_size} evenly"
                )

            self.block_size = kernel_block_size
            self.blocks_per_kv_block = block_size // kernel_block_size
            self.use_hybrid_blocks = True

        self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block

67
        self.block_table = self._make_buffer(
68
            self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
69
        )
70
71
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

72
73
74
        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens, dtype=torch.int64
        )
75
76
77
78
79
80
81
82

        if self.use_hybrid_blocks:
            self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
                1, -1
            )
        else:
            self._kernel_block_arange = None

83
84
85
86
87
88
89
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
90
        self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
91

92
93
    def append_row(
        self,
94
        block_ids: list[int],
95
        row_idx: int,
96
    ) -> None:
97
98
        if not block_ids:
            return
99
100
101
102

        if self.use_hybrid_blocks:
            block_ids = self._map_to_kernel_blocks(np.array(block_ids))

103
        num_blocks = len(block_ids)
104
105
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
106
        self.block_table.np[row_idx, start : start + num_blocks] = block_ids
107

108
    def add_row(self, block_ids: list[int], row_idx: int) -> None:
109
110
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)
111
112
113

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
114
115
        block_table_np = self.block_table.np
        block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
116
117
        self.num_blocks_per_row[tgt] = num_blocks

118
    def swap_row(self, src: int, tgt: int) -> None:
119
120
121
        src_tgt, tgt_src = [src, tgt], [tgt, src]
        self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
        self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
122

123
124
125
    def compute_slot_mapping(
        self, req_indices: np.ndarray, positions: np.ndarray
    ) -> None:
126
127
128
129
130
131
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size`
        # here because M (max_model_len) is not necessarily divisible by
        # block_size.
132
        if self.dcp_world_size > 1:
133
            # Note(hc): The DCP implement store kvcache with an interleave
134
135
136
137
138
139
            # style, the kvcache for the token whose token_idx is i is
            # always stored on the GPU whose dcp_rank equals i % cp_world_size:

            # Use a "virtual block" which equals to world_size * block_size
            # for block_table_indices calculation.
            virtual_block_size = self.block_size * self.dcp_world_size
140
141
142
143
            block_table_indices = (
                req_indices * self.max_num_blocks_per_req
                + positions // virtual_block_size
            )
144

145
            block_numbers = self.block_table.np.ravel()[block_table_indices]
146
147
148
            # Use virtual_block_size for mask calculation, which marks local
            # tokens.
            virtual_block_offsets = positions % virtual_block_size
149
150
151
152
153
154
            mask = (
                virtual_block_offsets
                // self.dcp_kv_cache_interleave_size
                % self.dcp_world_size
                == self.dcp_rank
            )
co63oc's avatar
co63oc committed
155
            # Calculate local block_offsets
156
157
158
159
160
161
            block_offsets = (
                virtual_block_offsets
                // (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
                * self.dcp_kv_cache_interleave_size
                + virtual_block_offsets % self.dcp_kv_cache_interleave_size
            )
co63oc's avatar
co63oc committed
162
            # Calculate slot_mapping
163
164
            slot_mapping = block_numbers * self.block_size + block_offsets
            # Write final slots, use -1 for not-local
165
166
167
            self.slot_mapping.np[: req_indices.shape[0]] = np.where(
                mask, slot_mapping, -1
            )
168
        else:
169
170
171
            block_table_indices = (
                req_indices * self.max_num_blocks_per_req + positions // self.block_size
            )
172

173
            block_numbers = self.block_table.np.ravel()[block_table_indices]
174
            block_offsets = positions % self.block_size
175
176
177
178
179
            np.add(
                block_numbers * self.block_size,
                block_offsets,
                out=self.slot_mapping.np[: req_indices.shape[0]],
            )
180
181

    def commit_block_table(self, num_reqs: int) -> None:
182
        self.block_table.copy_to_gpu(num_reqs)
183

184
    def commit_slot_mapping(self, num_tokens: int) -> None:
185
        self.slot_mapping.copy_to_gpu(num_tokens)
186

187
    def clear(self) -> None:
188
189
        self.block_table.gpu.fill_(0)
        self.block_table.cpu.fill_(0)
190

191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
    def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
        """Convert kv_manager_block_id IDs to kernel block IDs.

        Example:
            # kv_manager_block_ids: 32 tokens,
            # Kernel block size: 16 tokens
            # blocks_per_kv_block = 2
            >>> kv_manager_block_ids = np.array([0, 1, 2])
            >>> Result: [0, 1, 2, 3, 4, 5]

            # Each kv_manager_block_id maps to 2 kernel block id:
            # kv_manager_block_id 0 → kernel block id [0, 1]
            # kv_manager_block_id 1 → kernel block id [2, 3]
            # kv_manager_block_id 2 → kernel block id [4, 5]
        """
        if not self.use_hybrid_blocks:
            return kv_manager_block_ids

        kernel_block_ids = (
            kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
            + self._kernel_block_arange
        )

        return kernel_block_ids.reshape(-1)

216
    def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
217
        """Returns the device tensor of the block table."""
218
        return self.block_table.gpu[:num_reqs]
219
220
221

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
222
        return self.block_table.cpu
223
224
225

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
226
227
        return self.block_table.np

228
    def _make_buffer(
229
        self, *size: int | torch.SymInt, dtype: torch.dtype
230
231
232
233
    ) -> CpuGpuBuffer:
        return CpuGpuBuffer(
            *size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
        )
234
235
236
237
238


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group.

    Every row-editing, slot-mapping and commit call is forwarded to each
    group's BlockTable in order.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        block_sizes: list[int],
        kernel_block_sizes: list[int],
        num_speculative_tokens: int = 0,
        dcp_kv_cache_interleave_size: int = 1,
    ) -> None:
        """Create one BlockTable per KV-cache group.

        Args:
            max_num_reqs: Maximum number of concurrent requests supported.
            max_model_len: Maximum sequence length; sizes each table's width.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block tables.
            block_sizes: Allocation block size of each KV-cache group.
            kernel_block_sizes: Attention-kernel block size of each group;
                must have the same length as ``block_sizes``.
            num_speculative_tokens: Each table holds at least
                ``1 + num_speculative_tokens`` blocks per request.
            dcp_kv_cache_interleave_size: Forwarded to every BlockTable.

        Raises:
            ValueError: If ``kernel_block_sizes`` and ``block_sizes`` differ
                in length.
        """
        # Validate arguments before touching distributed state.
        if len(kernel_block_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )

        # Note(hc): each dcp rank only store
        # (max_model_len//dcp_world_size) tokens in kvcache,
        # so the block_size which used for calc max_num_blocks_per_req
        # must be multiplied by dcp_world_size.
        try:
            dcp_world_size = get_dcp_group().world_size
        except AssertionError:
            # DCP might not be initialized in testing
            dcp_world_size = 1

        self.block_tables = [
            BlockTable(
                block_size,
                max_num_reqs,
                max(
                    cdiv(max_model_len, block_size * dcp_world_size),
                    1 + num_speculative_tokens,
                ),
                max_num_batched_tokens,
                pin_memory,
                device,
                kernel_block_size,
                dcp_kv_cache_interleave_size,
            )
            for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
        ]

    def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        """Append ``block_ids[i]`` to row ``row_idx`` of group ``i``'s table."""
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        """Overwrite row ``row_idx`` of group ``i``'s table with ``block_ids[i]``."""
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy row ``src`` into row ``tgt`` in every group's table."""
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange rows ``src`` and ``tgt`` in every group's table."""
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def compute_slot_mapping(
        self, req_indices: np.ndarray, positions: np.ndarray
    ) -> None:
        """Compute each group's slot mapping for the same scheduled tokens."""
        for block_table in self.block_tables:
            block_table.compute_slot_mapping(req_indices, positions)

    def commit_block_table(self, num_reqs: int) -> None:
        """Push the first ``num_reqs`` rows of every table to the device."""
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        """Push the first ``num_tokens`` slot-mapping entries of every table."""
        for block_table in self.block_tables:
            block_table.commit_slot_mapping(num_tokens)

    def clear(self) -> None:
        """Clear every group's table."""
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]