block_table.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import numpy as np
import torch

7
from vllm.distributed import get_dcp_group, get_pcp_group
8
from vllm.logger import init_logger
9
from vllm.triton_utils import tl, triton
10
from vllm.utils.math_utils import cdiv
11
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
12
from vllm.v1.utils import CpuGpuBuffer
13
from vllm.v1.worker.cp_utils import get_total_cp_world_size
14
15
16
17
18
19
20

# Module-level logger, named after this module via __name__.
logger = init_logger(__name__)


class BlockTable:
    """Block table for a single KV cache group.

    For each request row it keeps the list of (kernel-level) block IDs in a
    ``[max_num_reqs, max_num_blocks_per_req]`` int32 ``CpuGpuBuffer``, plus a
    per-row count of valid entries (``num_blocks_per_row``). Rows are edited
    through the buffer's numpy view, pushed to the device by
    ``commit_block_table``, and read on the device by ``compute_slot_mapping``.
    """

    def __init__(
        self,
        block_size: int,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        kernel_block_size: int,
        cp_kv_cache_interleave_size: int,
    ):
        """
        Args:
            block_size: Block size used for KV cache memory allocation
            max_num_reqs: Maximum number of concurrent requests supported.
            max_num_blocks_per_req: Maximum number of blocks per request, in
                units of ``block_size``. It is scaled internally when kernel
                blocks are smaller than allocation blocks.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block table.
            kernel_block_size: The block_size of underlying attention kernel.
                Will be the same as `block_size` if `block_size` is supported
                by the attention kernel.
            cp_kv_cache_interleave_size: Number of consecutive tokens assigned
                to one context-parallel rank before moving on to the next rank
                when computing the slot mapping.

        Raises:
            ValueError: If ``kernel_block_size`` differs from ``block_size``
                and does not divide it evenly.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        if kernel_block_size == block_size:
            # Standard case: allocation and computation use same block size
            # No block splitting needed, direct mapping
            self.block_size = block_size
            self.blocks_per_kv_block = 1
            self.use_hybrid_blocks = False
        else:
            # Hybrid case: allocation block size differs from kernel block size
            # Memory blocks are subdivided to match kernel requirements
            # Example: 32-token memory blocks with 16-token kernel blocks
            # → Each memory block corresponds to 2 kernel blocks
            if block_size % kernel_block_size != 0:
                raise ValueError(
                    f"kernel_block_size {kernel_block_size} must divide "
                    f"kv_manager_block_size size {block_size} evenly"
                )

            self.block_size = kernel_block_size
            self.blocks_per_kv_block = block_size // kernel_block_size
            self.use_hybrid_blocks = True

        # The table stores kernel-level block IDs, so its width is scaled by
        # the number of kernel blocks per allocation block.
        self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block

        self.block_table = self._make_buffer(
            self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
        )
        # Number of valid entries at the start of each row of the table.
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens, dtype=torch.int64
        )

        if self.use_hybrid_blocks:
            # Row vector [[0, 1, ..., blocks_per_kv_block - 1]], reused by
            # map_to_kernel_blocks to expand every allocation block.
            self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
                1, -1
            )
        else:
            self._kernel_block_arange = None

        try:
            self.pcp_world_size = get_pcp_group().world_size
            self.pcp_rank = get_pcp_group().rank_in_group
        except AssertionError:
            # PCP might not be initialized in testing
            self.pcp_world_size = 1
            self.pcp_rank = 0

        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        """Append ``block_ids`` after the current valid entries of row ``row_idx``.

        In hybrid mode the allocation-level IDs are first expanded to
        kernel-level IDs. The row count is advanced only after the write
        succeeds. An append that would overflow the row makes numpy raise
        ``ValueError`` and leaves the row's bookkeeping unchanged.
        """
        if not block_ids:
            return

        if self.use_hybrid_blocks:
            block_ids = self.map_to_kernel_blocks(
                np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
            )

        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        self.block_table.np[row_idx, start : start + num_blocks] = block_ids
        # Only bump the count once the write above has succeeded.
        self.num_blocks_per_row[row_idx] += num_blocks

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Replace the contents of row ``row_idx`` with ``block_ids``.

        Only the row count is reset. Stale entries past the new count are not
        zeroed.
        """
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def clear_row(self, row_idx: int) -> None:
        """Zero the valid prefix of row ``row_idx`` and reset its count."""
        num_blocks = self.num_blocks_per_row[row_idx]
        if num_blocks > 0:
            self.block_table.np[row_idx, :num_blocks] = 0
        self.num_blocks_per_row[row_idx] = 0

    def move_row(self, src: int, tgt: int) -> None:
        """Copy the valid prefix of row ``src`` (and its count) into row ``tgt``.

        Row ``src`` is left unchanged. Entries of ``tgt`` past the copied
        prefix are not touched.
        """
        num_blocks = self.num_blocks_per_row[src]
        block_table_np = self.block_table.np
        block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Swap rows ``src`` and ``tgt`` (full rows and their counts).

        Fancy indexing on the right-hand side produces a copy, so each
        assignment below is a true swap.
        """
        src_tgt, tgt_src = [src, tgt], [tgt, src]
        self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
        self.block_table.np[src_tgt] = self.block_table.np[tgt_src]

    def compute_slot_mapping(
        self,
        num_reqs: int,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> None:
        """Fill ``self.slot_mapping.gpu`` with the KV-cache slot of every token.

        Launches ``_compute_slot_mapping_kernel`` with one program per request
        plus one extra program. The extra program writes ``PAD_SLOT_ID`` into
        entries ``num_tokens .. max_num_batched_tokens - 1``. Tokens whose KV
        entry belongs to a different context-parallel rank also get
        ``PAD_SLOT_ID``. The kernel reads ``self.block_table.gpu``, so pending
        row edits must be pushed with ``commit_block_table`` first.

        Args:
            num_reqs: Number of requests in the batch.
            query_start_loc: Per-request start offsets into ``positions``
                (``num_reqs + 1`` entries are read).
            positions: Position of each scheduled token within its request.
        """
        num_tokens = positions.shape[0]

        total_cp_world_size = self.pcp_world_size * self.dcp_world_size
        # Flatten (pcp_rank, dcp_rank) into one rank index; DCP varies fastest.
        total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank

        _compute_slot_mapping_kernel[(num_reqs + 1,)](
            num_tokens,
            self.max_num_batched_tokens,
            query_start_loc,
            positions,
            self.block_table.gpu,
            self.block_table.gpu.stride(0),
            self.block_size,
            self.slot_mapping.gpu,
            TOTAL_CP_WORLD_SIZE=total_cp_world_size,
            TOTAL_CP_RANK=total_cp_rank,
            CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
            PAD_ID=PAD_SLOT_ID,
            BLOCK_SIZE=1024,
        )

    def commit_block_table(self, num_reqs: int) -> None:
        """Push the host-side table to the device via ``copy_to_gpu(num_reqs)``."""
        self.block_table.copy_to_gpu(num_reqs)

    def clear(self) -> None:
        """Zero the whole block table (device and host) and reset all row counts."""
        self.block_table.gpu.fill_(0)
        self.block_table.cpu.fill_(0)
        # Keep the bookkeeping consistent with the zeroed table; otherwise a
        # later append_row() would write after stale, now-meaningless offsets.
        self.num_blocks_per_row.fill(0)

    @staticmethod
    def map_to_kernel_blocks(
        kv_manager_block_ids: np.ndarray,
        blocks_per_kv_block: int,
        kernel_block_arange: np.ndarray | None = None,
    ) -> np.ndarray:
        """Convert kv_manager_block_id IDs to kernel block IDs.

        Args:
            kv_manager_block_ids: Allocation-level block IDs. The array is
                flattened before expansion.
            blocks_per_kv_block: Number of kernel blocks per allocation block.
            kernel_block_arange: Optional precomputed
                ``np.arange(blocks_per_kv_block).reshape(1, -1)``. It is built
                on the fly when omitted or ``None``.

        Returns:
            ``kv_manager_block_ids`` unchanged when ``blocks_per_kv_block`` is
            1. Otherwise, a flat array in which every input ID ``b`` is
            replaced by ``b * blocks_per_kv_block + [0, ..., blocks_per_kv_block - 1]``.

        Example (32-token allocation blocks, 16-token kernel blocks):
            >>> BlockTable.map_to_kernel_blocks(np.array([0, 1, 2]), 2)
            array([0, 1, 2, 3, 4, 5])

            # kv_manager_block_id 0 → kernel block id [0, 1]
            # kv_manager_block_id 1 → kernel block id [2, 3]
            # kv_manager_block_id 2 → kernel block id [4, 5]
        """
        if blocks_per_kv_block == 1:
            return kv_manager_block_ids

        if kernel_block_arange is None:
            kernel_block_arange = np.arange(blocks_per_kv_block).reshape(1, -1)

        # Broadcast (n, 1) * k + (1, k) -> (n, k), then flatten row-major so
        # each allocation block's kernel blocks stay contiguous and in order.
        kernel_block_ids = (
            kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
            + kernel_block_arange
        )

        return kernel_block_ids.reshape(-1)

    def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table.gpu[:num_reqs]

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table.cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table.np

    def _make_buffer(
        self, *size: int | torch.SymInt, dtype: torch.dtype
    ) -> CpuGpuBuffer:
        """Allocate a CpuGpuBuffer on this table's device / pinning settings."""
        return CpuGpuBuffer(
            *size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
        )
221
222
223
224
225


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

226
227
228
229
230
231
232
233
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        block_sizes: list[int],
234
        kernel_block_sizes: list[int],
235
        max_num_blocks: list[int] | None = None,
236
        cp_kv_cache_interleave_size: int = 1,
237
    ) -> None:
238
239
240
241
242
        if len(kernel_block_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
        if max_num_blocks is None:
            # Note(hc): each dcp rank only store
            # (max_model_len//dcp_world_size) tokens in kvcache,
            # so the block_size which used for calc max_num_blocks_per_req
            # must be multiplied by dcp_world_size.
            total_cp_world_size = get_total_cp_world_size()
            max_num_blocks = [
                cdiv(max_model_len, block_size * total_cp_world_size)
                for block_size in block_sizes
            ]

        if len(max_num_blocks) != len(block_sizes):
            raise ValueError(
                f"max_num_blocks length ({len(max_num_blocks)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )
259

260
        self.block_tables = [
261
            BlockTable(
262
263
                block_size,
                max_num_reqs,
264
                max_num_blocks_per_req,
265
266
267
                max_num_batched_tokens,
                pin_memory,
                device,
268
                kernel_block_size,
269
                cp_kv_cache_interleave_size,
270
            )
271
272
273
            for block_size, kernel_block_size, max_num_blocks_per_req in zip(
                block_sizes, kernel_block_sizes, max_num_blocks
            )
274
275
        ]

276
    def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
277
278
279
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

280
    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
281
282
283
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

284
285
286
287
    def clear_row(self, row_idx: int) -> None:
        for block_table in self.block_tables:
            block_table.clear_row(row_idx)

288
289
290
291
292
293
294
295
    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

296
    def compute_slot_mapping(
297
298
299
300
        self,
        num_reqs: int,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
301
    ) -> None:
302
        for block_table in self.block_tables:
303
            block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)
304
305
306
307
308

    def commit_block_table(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

309
310
311
312
313
314
315
    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373


# Triton kernel: compute the KV-cache slot index of every scheduled token.
#
# Each program but the last handles one request. It walks that request's
# token range [query_start_loc[req_idx], query_start_loc[req_idx + 1]) in
# tiles of BLOCK_SIZE tokens. For each token it looks up the block number
# for the token's position in the request's row of the block table, then
# writes block_number * block_size + in-block offset into slot_mapping.
#
# The last program in the grid handles no request. It fills
# slot_mapping[num_tokens:max_num_tokens] with PAD_ID. The caller in this
# file (BlockTable.compute_slot_mapping) launches num_reqs + 1 programs.
#
# Context parallelism: positions are grouped into "virtual" blocks of
# block_size * TOTAL_CP_WORLD_SIZE tokens, one per block-table entry. Within
# a virtual block, chunks of CP_KV_CACHE_INTERLEAVE_SIZE tokens are dealt
# round-robin to CP ranks. A token gets a real slot only if its chunk belongs
# to TOTAL_CP_RANK; otherwise its slot is PAD_ID.
#
# BLOCK_SIZE is the per-iteration tile width in tokens. It is unrelated to
# the KV-cache `block_size` argument.
@triton.jit
def _compute_slot_mapping_kernel(
    num_tokens,
    max_num_tokens,
    query_start_loc_ptr,  # [num_reqs + 1], int32
    positions_ptr,  # [num_tokens], int64
    block_table_ptr,  # [max_num_reqs, max_num_blocks_per_req], int32 (flat)
    block_table_stride,  # max_num_blocks_per_req
    block_size,
    slot_mapping_ptr,  # [max_num_tokens], int64
    TOTAL_CP_WORLD_SIZE: tl.constexpr,
    TOTAL_CP_RANK: tl.constexpr,
    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)

    # The extra, last program does padding only.
    if req_idx == tl.num_programs(0) - 1:
        # Pad remaining slots for CUDA graph compatibility.
        for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
            offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(
                slot_mapping_ptr + offsets,
                PAD_ID,
                mask=offsets < max_num_tokens,
            )
        return

    # [start_idx, end_idx) is this request's slice of the flattened batch.
    start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64)
    end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64)

    # One block-table entry covers block_size tokens per CP rank, i.e.
    # block_size * TOTAL_CP_WORLD_SIZE positions of the full sequence.
    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
    row_offset = req_idx * block_table_stride
    for i in range(start_idx, end_idx, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < end_idx
        pos = tl.load(positions_ptr + offsets, mask=mask, other=0)
        block_indices = pos // virtual_block_size
        # Unmasked load: masked-out lanes read pos 0 (other=0) and thus entry
        # 0 of this request's row, so the address stays in bounds.
        block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to(
            tl.int64
        )

        virtual_block_offsets = pos - block_indices * virtual_block_size
        # Owner rank of this token: interleave-sized chunks are assigned to
        # ranks round-robin within the virtual block.
        is_local = (
            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
        # Offset inside this rank's physical block: (index of this rank's
        # chunk) * chunk size + offset within the chunk.
        local_block_offsets = (
            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
        )

        slot_ids = block_numbers * block_size + local_block_offsets
        # Tokens stored on another CP rank get PAD_ID on this rank.
        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
        tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask)