block_table.py 11 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch

7
from vllm.triton_utils import tl, triton
Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.utils.math_utils import cdiv
9
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
10
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13
14
15
16
17
18
19
20


class BlockTables:
    """Block tables and slot mappings for every KV-cache group.

    For each KV-cache group this object keeps:

    * a persistent block table (a ``StagedWriteTensor``) whose rows are indexed
      by request state index and are updated through ``append_block_ids`` and
      ``apply_staged_writes``;
    * an "input" block table of the same shape, created with
      ``torch.zeros_like``, whose rows are gathered in batch order by
      ``gather_block_tables`` for the model's forward pass.

    It also owns the ``slot_mappings`` buffer that ``compute_slot_mappings``
    fills with per-token KV-cache slot indices.
    """

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
        cp_size: int = 1,
        cp_rank: int = 0,
        cp_interleave: int = 1,
    ):
        self.block_sizes = block_sizes
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        self.device = device

        # Context-parallel (DCP) configuration.
        self.cp_size = cp_size
        self.cp_rank = cp_rank
        self.cp_interleave = cp_interleave

        self.num_kv_cache_groups = len(self.block_sizes)
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        # NOTE: max_num_blocks differs between groups whose block sizes differ.
        self.block_tables: list[StagedWriteTensor] = []
        for block_size in self.block_sizes:
            # When using DCP, each request's KV cache is sharded among different ranks.
            # As a result, one block on the current rank covers `block_size * cp_size`
            # tokens in the full, global (unsharded) sequence.
            max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
            block_table = StagedWriteTensor(
                (self.max_num_reqs, max_num_blocks),
                dtype=torch.int32,
                device=self.device,
            )
            self.block_tables.append(block_table)

        # Device-side tables of per-group base addresses and row strides, so a
        # single kernel launch can address every group's block table.
        self.block_table_ptrs = self._make_ptr_tensor(
            [b.gpu for b in self.block_tables]
        )
        self.block_table_strides = torch.tensor(
            [b.gpu.stride(0) for b in self.block_tables],
            dtype=torch.int64,
            device=self.device,
        )
        self.block_sizes_tensor = torch.tensor(
            self.block_sizes, dtype=torch.int32, device=self.device
        )
        # Number of valid block IDs per (group, request); mirrored to the GPU in
        # apply_staged_writes().
        self.num_blocks = UvaBackedTensor(
            (self.num_kv_cache_groups, self.max_num_reqs),
            dtype=torch.int32,
        )

        # Block tables used for model's forward pass.
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.input_block_tables: list[torch.Tensor] = [
            torch.zeros_like(b.gpu) for b in self.block_tables
        ]
        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
        # Row width (max_num_blocks) of each group's input block table. Groups
        # may have different block sizes and therefore different widths, so
        # the gather kernel needs a per-group bound when zeroing padded rows;
        # a single shared width would overrun narrower groups' rows.
        self.input_block_table_widths = torch.tensor(
            [b.shape[1] for b in self.input_block_tables],
            dtype=torch.int32,
            device=self.device,
        )

        self.slot_mappings = torch.zeros(
            self.num_kv_cache_groups,
            self.max_num_batched_tokens,
            dtype=torch.int64,
            device=self.device,
        )

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        """Return a device tensor holding the ``data_ptr()`` of each tensor in ``x``."""
        # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
        return torch.tensor(
            [t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
        )

    def append_block_ids(
        self,
        req_index: int,
        new_block_ids: tuple[list[int], ...],
        overwrite: bool,
    ) -> None:
        """Stage new block IDs for one request in every KV-cache group.

        Args:
            req_index: Row of the request in the persistent block tables.
            new_block_ids: One list of block IDs per KV-cache group.
            overwrite: If True, write from column 0 (replacing the row's
                contents); otherwise append after the request's current blocks.

        The host-side block count is updated immediately; the block-table
        writes are only staged here and are applied by ``apply_staged_writes``.
        """
        for i in range(self.num_kv_cache_groups):
            start = 0 if overwrite else self.num_blocks.np[i, req_index]
            block_ids = new_block_ids[i]
            self.block_tables[i].stage_write(req_index, start, block_ids)
            self.num_blocks.np[i, req_index] = start + len(block_ids)

    def apply_staged_writes(self) -> None:
        """Apply all staged block-table writes and sync block counts to the GPU."""
        # TODO(woosuk): This can be inefficient since it launches one kernel per
        # block table. Implement a kernel to handle all block tables at once.
        for block_table in self.block_tables:
            block_table.apply_write()
        self.num_blocks.copy_to_uva()

    def gather_block_tables(
        self,
        idx_mapping: torch.Tensor,
        num_reqs_padded: int,
    ) -> tuple[torch.Tensor, ...]:
        """Gather the scheduled requests' rows into the input block tables.

        Row ``b`` of each group's input table receives the first
        ``num_blocks[group, idx_mapping[b]]`` entries of persistent row
        ``idx_mapping[b]``; entries past that count are not cleared. Rows in
        ``[num_reqs, num_reqs_padded)`` are zeroed over their full width.

        Args:
            idx_mapping: Request state index for each batch position; its
                length is the actual number of requests.
            num_reqs_padded: Number of rows to return (>= ``len(idx_mapping)``).

        Returns:
            One view of shape ``[num_reqs_padded, max_num_blocks]`` per group,
            backed by the persistent input block tables.
        """
        num_reqs = idx_mapping.shape[0]
        # Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
        _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
            idx_mapping,
            self.block_table_ptrs,
            self.input_block_table_ptrs,
            self.block_table_strides,
            self.num_blocks.gpu,
            self.num_blocks.gpu.stride(0),
            num_reqs,
            self.input_block_table_widths,  # per-group max_num_blocks
            BLOCK_SIZE=1024,  # type: ignore
        )
        return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        """Return the first ``num_reqs`` rows of each input block table as-is."""
        # NOTE(woosuk): The output may be used for CUDA graph capture.
        # Therefore, this method must return the persistent tensor
        # with the same memory address as that used during the model's forward pass,
        # rather than allocating a new tensor.
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def compute_slot_mappings(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
        num_tokens_padded: int,
    ) -> torch.Tensor:
        """Compute the KV-cache slot of every scheduled token for each group.

        For token position ``p`` of request ``b`` the block is looked up in
        persistent row ``idx_mapping[b]`` at index
        ``p // (block_size * cp_size)``. Tokens that belong to another
        context-parallel rank get ``PAD_SLOT_ID``, and every slot from
        ``query_start_loc[num_reqs]`` up to ``max_num_batched_tokens`` is set
        to ``PAD_SLOT_ID``.

        Returns:
            View of shape ``[num_kv_cache_groups, num_tokens_padded]`` into the
            persistent ``slot_mappings`` buffer.
        """
        num_reqs = idx_mapping.shape[0]
        num_groups = self.num_kv_cache_groups
        # The extra program (index num_reqs) pads the tail of slot_mappings.
        _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
            self.max_num_batched_tokens,
            idx_mapping,
            query_start_loc,
            positions,
            self.block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mappings,
            self.slot_mappings.stride(0),
            self.cp_rank,
            CP_SIZE=self.cp_size,
            CP_INTERLEAVE=self.cp_interleave,
            PAD_ID=PAD_SLOT_ID,
            TRITON_BLOCK_SIZE=1024,  # type: ignore
        )
        return self.slot_mappings[:, :num_tokens_padded]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        """Fill all slot mappings with PAD_SLOT_ID and return the first ``num_tokens``."""
        # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
        # This is because the padding logic is complex and kernels may access beyond
        # the requested range.
        self.slot_mappings.fill_(PAD_SLOT_ID)
        # NOTE(woosuk): The output may be used for CUDA graph capture.
        # Therefore, this method must return the persistent tensor
        # with the same memory address as that used during the model's forward pass,
        # rather than allocating a new tensor.
        return self.slot_mappings[:, :num_tokens]


@triton.jit
def _gather_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_kv_cache_groups]
    dst_block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
    num_blocks_stride,
    num_reqs,  # actual number of requests (for padding)
    max_num_blocks_ptr,  # [num_kv_cache_groups] row width of each dst table
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one request's block-table row for one KV-cache group.

    Grid: ``(num_kv_cache_groups, num_reqs_padded)``. Programs with
    ``batch_idx >= num_reqs`` zero their destination row over that group's
    own width (loaded from ``max_num_blocks_ptr``) so groups of different
    widths are never overrun. The same per-group stride is used for the
    source and destination rows, which relies on the destination tables being
    ``torch.zeros_like`` copies of the source tables.
    """
    # kv cache group id
    group_id = tl.program_id(0)
    batch_idx = tl.program_id(1)

    stride = tl.load(block_table_strides + group_id)
    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
    dst_row_ptr = dst_block_table_ptr + batch_idx * stride

    if batch_idx >= num_reqs:
        # Zero out padded rows, bounded by this group's row width.
        max_num_blocks = tl.load(max_num_blocks_ptr + group_id)
        for i in tl.range(0, max_num_blocks, BLOCK_SIZE):
            offset = i + tl.arange(0, BLOCK_SIZE)
            tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks)
        return

    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
    group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
    num_blocks = tl.load(group_num_blocks_ptr + req_idx)

    src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
    src_row_ptr = src_block_table_ptr + req_idx * stride

    # Copy only the valid prefix; stale entries past num_blocks are left as-is.
    for i in tl.range(0, num_blocks, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
        tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)


@triton.jit
def _compute_slot_mappings_kernel(
    max_num_tokens,
    idx_mapping,  # [num_reqs]
    query_start_loc,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    block_sizes,  # [num_kv_cache_groups]
    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
    slot_mappings_stride,
    cp_rank,
    CP_SIZE: tl.constexpr,
    CP_INTERLEAVE: tl.constexpr,
    PAD_ID: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    """Write the KV-cache slot index of each token of one request.

    Axis 0 of the grid selects the KV-cache group and axis 1 selects the
    request. The last program along axis 1 handles no request. Instead, it
    fills slots ``[query_start_loc[its index], max_num_tokens)`` with
    ``PAD_ID``.

    Each other program handles the tokens in
    ``[query_start_loc[b], query_start_loc[b + 1])``. A token at position
    ``p`` falls in logical block ``p // (block_size * CP_SIZE)`` of its
    request's block-table row.

    With context parallelism, only offsets owned by ``cp_rank`` get a real
    slot; the rest get ``PAD_ID``.
    """
    group_id = tl.program_id(0)
    prog_idx = tl.program_id(1)
    out_row = slot_mappings_ptr + group_id * slot_mappings_stride

    if prog_idx == tl.num_programs(1) - 1:
        # Pad remaining slots to -1. This is needed for CUDA graphs.
        # Start from the actual (unpadded) token count so the gap between
        # actual and padded tokens cannot keep stale slot IDs from earlier
        # chunks during chunked prefill.
        pad_from = tl.load(query_start_loc + prog_idx)
        for chunk_start in range(pad_from, max_num_tokens, TRITON_BLOCK_SIZE):
            cols = chunk_start + tl.arange(0, TRITON_BLOCK_SIZE)
            tl.store(out_row + cols, PAD_ID, mask=cols < max_num_tokens)
        return

    # Base of this request's row in the group's persistent block table.
    table_base = _load_ptr(block_table_ptrs + group_id, tl.int32)
    row_stride = tl.load(block_table_strides + group_id)
    block_size = tl.load(block_sizes + group_id)
    req_row = table_base + tl.load(idx_mapping + prog_idx) * row_stride
    # Number of global token positions covered by one local block.
    span = block_size * CP_SIZE

    tok_begin = tl.load(query_start_loc + prog_idx)
    tok_end = tl.load(query_start_loc + prog_idx + 1)
    for chunk_start in range(tok_begin, tok_end, TRITON_BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, TRITON_BLOCK_SIZE)
        valid = cols < tok_end
        # Masked-off lanes read position 0, so the block lookup below stays
        # in row bounds; their results are discarded by the masked store.
        tok_pos = tl.load(pos + cols, mask=valid, other=0)
        blk_idx = tok_pos // span
        blk_off = tok_pos % span
        phys_blk = tl.load(req_row + blk_idx)

        if CP_SIZE == 1:
            # No context parallelism: the offset maps directly into the block.
            slots = phys_blk * block_size + blk_off
        else:
            # Offsets are dealt out to ranks in runs of CP_INTERLEAVE tokens.
            owner_rank = (blk_off // CP_INTERLEAVE) % CP_SIZE
            full_rounds = blk_off // (CP_INTERLEAVE * CP_SIZE)
            within_run = blk_off % CP_INTERLEAVE
            local_off = full_rounds * CP_INTERLEAVE + within_run
            slots = tl.where(
                owner_rank == cp_rank, phys_blk * block_size + local_off, PAD_ID
            )

        tl.store(out_row + cols, slots, mask=valid)


@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    """Read an address stored at ``ptr_to_ptr`` and return it as an ``elem_dtype`` pointer.

    The returned pointer carries a ``tl.multiple_of(..., 16)`` hint. The
    compiler does not check this hint; it only holds if every stored address
    is 16-aligned.
    NOTE(review): the stored addresses come from ``data_ptr()`` of whole
    tensors, which should satisfy this — confirm for ``StagedWriteTensor.gpu``.
    """
    typed_ptr = tl.cast(tl.load(ptr_to_ptr), tl.pointer_type(elem_dtype))
    return tl.multiple_of(typed_ptr, 16)