# block_table.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import numpy as np
import torch

7
from vllm.distributed import get_dcp_group, get_pcp_group
8
from vllm.logger import init_logger
9
from vllm.triton_utils import tl, triton
10
from vllm.utils.math_utils import cdiv
11
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
12
from vllm.v1.utils import CpuGpuBuffer
13
from vllm.v1.worker.cp_utils import get_total_cp_world_size
14
15
16
17
18
19
20

logger = init_logger(__name__)


class BlockTable:
    def __init__(
        self,
21
        block_size: int,
22
23
        max_num_reqs: int,
        max_num_blocks_per_req: int,
24
        max_num_batched_tokens: int,
25
26
        pin_memory: bool,
        device: torch.device,
27
        kernel_block_size: int,
28
        cp_kv_cache_interleave_size: int,
29
    ):
30
31
32
33
34
35
36
37
38
39
40
41
        """
        Args:
            block_size: Block size used for KV cache memory allocation
            max_num_reqs: Maximum number of concurrent requests supported.
            max_num_blocks_per_req: Maximum number of blocks per request.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block table.
            kernel_block_size: The block_size of underlying attention kernel.
                Will be the same as `block_size` if `block_size` is supported
                by the attention kernel.
        """
42
        self.max_num_reqs = max_num_reqs
43
        self.max_num_batched_tokens = max_num_batched_tokens
44
45
46
        self.pin_memory = pin_memory
        self.device = device

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
        if kernel_block_size == block_size:
            # Standard case: allocation and computation use same block size
            # No block splitting needed, direct mapping
            self.block_size = block_size
            self.blocks_per_kv_block = 1
            self.use_hybrid_blocks = False
        else:
            # Hybrid case: allocation block size differs from kernel block size
            # Memory blocks are subdivided to match kernel requirements
            # Example: 32-token memory blocks with 16-token kernel blocks
            # → Each memory block corresponds to 2 kernel blocks
            if block_size % kernel_block_size != 0:
                raise ValueError(
                    f"kernel_block_size {kernel_block_size} must divide "
                    f"kv_manager_block_size size {block_size} evenly"
                )

            self.block_size = kernel_block_size
            self.blocks_per_kv_block = block_size // kernel_block_size
            self.use_hybrid_blocks = True

        self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block

70
        self.block_table = self._make_buffer(
71
            self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
72
        )
73
74
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

75
76
77
        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens, dtype=torch.int64
        )
78
79
80
81
82
83
84
85

        if self.use_hybrid_blocks:
            self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
                1, -1
            )
        else:
            self._kernel_block_arange = None

86
87
88
89
        try:
            self.pcp_world_size = get_pcp_group().world_size
            self.pcp_rank = get_pcp_group().rank_in_group
        except AssertionError:
90
            # PCP might not be initialized in testing
91
92
            self.pcp_world_size = 1
            self.pcp_rank = 0
93
94
95
96
97
98
99
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
100
        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
101

102
103
    def append_row(
        self,
104
        block_ids: list[int],
105
        row_idx: int,
106
    ) -> None:
107
108
        if not block_ids:
            return
109
110

        if self.use_hybrid_blocks:
111
112
113
            block_ids = self.map_to_kernel_blocks(
                np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
            )
114

115
        num_blocks = len(block_ids)
116
117
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
118
        self.block_table.np[row_idx, start : start + num_blocks] = block_ids
119

120
    def add_row(self, block_ids: list[int], row_idx: int) -> None:
121
122
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)
123
124
125

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
126
127
        block_table_np = self.block_table.np
        block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
128
129
        self.num_blocks_per_row[tgt] = num_blocks

130
    def swap_row(self, src: int, tgt: int) -> None:
131
132
133
        src_tgt, tgt_src = [src, tgt], [tgt, src]
        self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
        self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
134

135
    def compute_slot_mapping(
136
137
138
139
        self,
        num_reqs: int,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
140
    ) -> None:
141
        num_tokens = positions.shape[0]
142
143
        total_cp_world_size = self.pcp_world_size * self.dcp_world_size
        total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
        _compute_slot_mapping_kernel[(num_reqs + 1,)](
            num_tokens,
            self.max_num_batched_tokens,
            query_start_loc,
            positions,
            self.block_table.gpu,
            self.block_table.gpu.stride(0),
            self.block_size,
            self.slot_mapping.gpu,
            TOTAL_CP_WORLD_SIZE=total_cp_world_size,
            TOTAL_CP_RANK=total_cp_rank,
            CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
            PAD_ID=PAD_SLOT_ID,
            BLOCK_SIZE=1024,
        )
159
160

    def commit_block_table(self, num_reqs: int) -> None:
161
        self.block_table.copy_to_gpu(num_reqs)
162
163

    def clear(self) -> None:
164
165
        self.block_table.gpu.fill_(0)
        self.block_table.cpu.fill_(0)
166

167
168
169
170
171
172
    @staticmethod
    def map_to_kernel_blocks(
        kv_manager_block_ids: np.ndarray,
        blocks_per_kv_block: int,
        kernel_block_arange: np.ndarray,
    ) -> np.ndarray:
173
174
175
176
177
178
179
180
181
182
183
184
185
186
        """Convert kv_manager_block_id IDs to kernel block IDs.

        Example:
            # kv_manager_block_ids: 32 tokens,
            # Kernel block size: 16 tokens
            # blocks_per_kv_block = 2
            >>> kv_manager_block_ids = np.array([0, 1, 2])
            >>> Result: [0, 1, 2, 3, 4, 5]

            # Each kv_manager_block_id maps to 2 kernel block id:
            # kv_manager_block_id 0 → kernel block id [0, 1]
            # kv_manager_block_id 1 → kernel block id [2, 3]
            # kv_manager_block_id 2 → kernel block id [4, 5]
        """
187
        if blocks_per_kv_block == 1:
188
189
190
            return kv_manager_block_ids

        kernel_block_ids = (
191
192
            kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
            + kernel_block_arange
193
194
195
196
        )

        return kernel_block_ids.reshape(-1)

197
    def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
198
        """Returns the device tensor of the block table."""
199
        return self.block_table.gpu[:num_reqs]
200
201
202

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
203
        return self.block_table.cpu
204
205
206

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
207
208
        return self.block_table.np

209
    def _make_buffer(
210
        self, *size: int | torch.SymInt, dtype: torch.dtype
211
212
213
214
    ) -> CpuGpuBuffer:
        return CpuGpuBuffer(
            *size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
        )
215
216
217
218
219


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

220
221
222
223
224
225
226
227
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        block_sizes: list[int],
228
        kernel_block_sizes: list[int],
229
        max_num_blocks: list[int] | None = None,
230
        cp_kv_cache_interleave_size: int = 1,
231
    ) -> None:
232
233
234
235
236
        if len(kernel_block_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
        if max_num_blocks is None:
            # Note(hc): each dcp rank only store
            # (max_model_len//dcp_world_size) tokens in kvcache,
            # so the block_size which used for calc max_num_blocks_per_req
            # must be multiplied by dcp_world_size.
            total_cp_world_size = get_total_cp_world_size()
            max_num_blocks = [
                cdiv(max_model_len, block_size * total_cp_world_size)
                for block_size in block_sizes
            ]

        if len(max_num_blocks) != len(block_sizes):
            raise ValueError(
                f"max_num_blocks length ({len(max_num_blocks)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )
253

254
        self.block_tables = [
255
            BlockTable(
256
257
                block_size,
                max_num_reqs,
258
                max_num_blocks_per_req,
259
260
261
                max_num_batched_tokens,
                pin_memory,
                device,
262
                kernel_block_size,
263
                cp_kv_cache_interleave_size,
264
            )
265
266
267
            for block_size, kernel_block_size, max_num_blocks_per_req in zip(
                block_sizes, kernel_block_sizes, max_num_blocks
            )
268
269
        ]

270
    def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
271
272
273
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

274
    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
275
276
277
278
279
280
281
282
283
284
285
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

286
    def compute_slot_mapping(
287
288
289
290
        self,
        num_reqs: int,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
291
    ) -> None:
292
        for block_table in self.block_tables:
293
            block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)
294
295
296
297
298

    def commit_block_table(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

299
300
301
302
303
304
305
    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363


@triton.jit
def _compute_slot_mapping_kernel(
    num_tokens,
    max_num_tokens,
    query_start_loc_ptr,  # [num_reqs + 1], int32
    positions_ptr,  # [num_tokens], int64
    block_table_ptr,  # [max_num_reqs, max_num_blocks_per_req], int32 (flat)
    block_table_stride,  # max_num_blocks_per_req
    block_size,
    slot_mapping_ptr,  # [max_num_tokens], int64
    TOTAL_CP_WORLD_SIZE: tl.constexpr,
    TOTAL_CP_RANK: tl.constexpr,
    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Writes one slot id per token into `slot_mapping_ptr`.
    #
    # Launch layout: programs 0 .. N-2 each handle one request (the tokens in
    # [query_start_loc[req], query_start_loc[req + 1])); the last program
    # handles no request and only writes PAD_ID into the tail
    # [num_tokens, max_num_tokens) of the slot mapping.
    #
    # Tokens whose position does not belong to TOTAL_CP_RANK under the
    # interleaving scheme below are written as PAD_ID.
    req_idx = tl.program_id(0)

    if req_idx == tl.num_programs(0) - 1:
        # Pad remaining slots for CUDA graph compatibility.
        for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
            offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(
                slot_mapping_ptr + offsets,
                PAD_ID,
                mask=offsets < max_num_tokens,
            )
        return

    # Token range of this request within the flattened batch.
    start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64)
    end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64)

    # A "virtual" block spans the tokens of one block table entry across all
    # CP ranks: block_size tokens per rank times TOTAL_CP_WORLD_SIZE ranks.
    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
    # Offset of this request's row in the flat block table.
    row_offset = req_idx * block_table_stride
    for i in range(start_idx, end_idx, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < end_idx
        # Masked-out lanes read position 0 (`other=0`).
        pos = tl.load(positions_ptr + offsets, mask=mask, other=0)
        block_indices = pos // virtual_block_size
        # This load is unmasked; masked-out lanes have block index 0 and so
        # read the first entry of this request's row.
        block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to(
            tl.int64
        )

        # Offset of the token inside its virtual block.
        virtual_block_offsets = pos - block_indices * virtual_block_size
        # Tokens are dealt to ranks in chunks of CP_KV_CACHE_INTERLEAVE_SIZE,
        # round-robin; the token is local if its chunk lands on this rank.
        is_local = (
            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
        # Offset inside the rank-local block: number of full rounds completed
        # times the chunk size, plus the position inside the current chunk.
        local_block_offsets = (
            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
        )

        slot_ids = block_numbers * block_size + local_block_offsets
        # Non-local tokens get PAD_ID instead of a slot.
        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
        tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask)