"vllm/entrypoints/serve/lora/protocol.py" did not exist on "252c011012658490c2e970c5bc739a84bf5759c0"
block_table.py 10.2 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch

7
from vllm.triton_utils import tl, triton
Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.utils.math_utils import cdiv
9
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
10
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13
14
15
16
17
18
19
20


class BlockTables:
    """Per-KV-cache-group block tables and slot mappings for the GPU worker.

    One block table is kept per KV cache group. Writes to the block tables are
    staged through ``StagedWriteTensor`` and applied in ``apply_staged_writes``.
    A second, persistent set of tensors (``input_block_tables``) is filled by
    ``gather_block_tables`` and handed to the model's forward pass.
    """

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
        cp_size: int = 1,
        cp_rank: int = 0,
        cp_interleave: int = 1,
    ):
        """Allocate the block tables, pointer tables and slot-mapping buffer.

        Args:
            block_sizes: Block size of each KV cache group; its length defines
                the number of KV cache groups.
            max_num_reqs: Number of rows allocated in every block table.
            max_num_batched_tokens: Number of columns of the slot-mapping
                buffer.
            max_model_len: Maximum sequence length, used to size the block
                table rows.
            device: Device on which the tensors are allocated.
            cp_size: Context-parallel world size (1 means no sharding).
            cp_rank: Rank of this worker within the context-parallel group.
            cp_interleave: Number of consecutive tokens assigned to one rank
                before moving on to the next rank.
        """
        self.block_sizes = block_sizes
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        self.device = device

        self.cp_size = cp_size
        self.cp_rank = cp_rank
        self.cp_interleave = cp_interleave

        self.num_kv_cache_groups = len(self.block_sizes)
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.block_tables: list[StagedWriteTensor] = []
        for block_size in self.block_sizes:
            # When using DCP, each request's KV cache is sharded among different ranks.
            # As a result, one block on the current rank covers `block_size * cp_size`
            # tokens in the full, global (unsharded) sequence.
            max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
            block_table = StagedWriteTensor(
                (self.max_num_reqs, max_num_blocks),
                dtype=torch.int32,
                device=device,
            )
            self.block_tables.append(block_table)

        # Raw addresses and row strides of the per-group block tables, so that
        # a single kernel launch can address every group's table.
        self.block_table_ptrs = self._make_ptr_tensor(
            [b.gpu for b in self.block_tables]
        )
        self.block_table_strides = torch.tensor(
            [b.gpu.stride(0) for b in self.block_tables],
            dtype=torch.int64,
            device=self.device,
        )

        self.block_sizes_tensor = torch.tensor(
            self.block_sizes, dtype=torch.int32, device=self.device
        )
        # [num_kv_cache_groups, max_num_reqs]: number of block IDs recorded for
        # each request row in each group (see `append_block_ids`).
        self.num_blocks = UvaBackedTensor(
            (self.num_kv_cache_groups, self.max_num_reqs),
            dtype=torch.int32,
        )

        # Block tables used for model's forward pass.
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.input_block_tables: list[torch.Tensor] = [
            torch.zeros_like(b.gpu) for b in self.block_tables
        ]
        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

        # [num_kv_cache_groups, max_num_batched_tokens]
        self.slot_mappings = torch.zeros(
            self.num_kv_cache_groups,
            self.max_num_batched_tokens,
            dtype=torch.int64,
            device=self.device,
        )

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        """Return a 1-D tensor on ``self.device`` holding ``data_ptr()`` of each tensor."""
        # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
        return torch.tensor(
            [t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
        )

    def append_block_ids(
        self,
        req_index: int,
        new_block_ids: tuple[list[int], ...],
        overwrite: bool,
    ) -> None:
        """Stage new block IDs for one request in every KV cache group.

        Args:
            req_index: Row of the request in the block tables.
            new_block_ids: One list of block IDs per KV cache group.
            overwrite: If True, the IDs are written starting at column 0;
                otherwise they are written after the IDs already recorded
                for this request.
        """
        for i in range(self.num_kv_cache_groups):
            start = 0 if overwrite else self.num_blocks.np[i, req_index]
            block_ids = new_block_ids[i]
            self.block_tables[i].stage_write(req_index, start, block_ids)
            self.num_blocks.np[i, req_index] = start + len(block_ids)

    def apply_staged_writes(self) -> None:
        """Apply all staged block-table writes and publish the block counts."""
        # TODO(woosuk): This can be inefficient since it launches one kernel per
        # block table. Implement a kernel to handle all block tables at once.
        for block_table in self.block_tables:
            block_table.apply_write()
        self.num_blocks.copy_to_uva()

    def gather_block_tables(
        self, idx_mapping: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Gather the block-table rows of the batched requests.

        For every batch position ``b``, the recorded block IDs of request row
        ``idx_mapping[b]`` are copied into row ``b`` of the persistent
        ``input_block_tables``. Only the recorded number of blocks is copied
        per row; entries beyond it are left untouched.

        Args:
            idx_mapping: 1-D tensor mapping batch position to request row.

        Returns:
            One ``[num_reqs, max_num_blocks]`` view per KV cache group into
            ``input_block_tables``.
        """
        num_reqs = idx_mapping.shape[0]
        _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
            idx_mapping,
            self.block_table_ptrs,
            self.input_block_table_ptrs,
            self.block_table_strides,
            self.num_blocks.gpu,
            self.num_blocks.gpu.stride(0),
            BLOCK_SIZE=1024,  # type: ignore
        )
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        """Return the first ``num_reqs`` rows of each persistent input block table."""
        # NOTE(woosuk): The output may be used for CUDA graph capture.
        # Therefore, this method must return the persistent tensor
        # with the same memory address as that used during the model's forward pass,
        # rather than allocating a new tensor.
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def compute_slot_mappings(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the KV cache slot of every scheduled token for every group.

        The kernel is launched with one extra program per group (grid size
        ``num_reqs + 1`` on axis 1); that extra program fills the tail
        ``[num_tokens, max_num_batched_tokens)`` with ``PAD_SLOT_ID``.

        Args:
            idx_mapping: 1-D tensor mapping batch position to request row.
            query_start_loc: ``[num_reqs + 1]`` token offsets; request ``b``
                owns tokens ``[query_start_loc[b], query_start_loc[b + 1])``.
            positions: ``[num_tokens]`` position of each token in its sequence.

        Returns:
            A ``[num_kv_cache_groups, num_tokens]`` view into the persistent
            ``slot_mappings`` buffer.
        """
        num_reqs = idx_mapping.shape[0]
        num_tokens = positions.shape[0]
        num_groups = self.num_kv_cache_groups
        _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
            num_tokens,
            self.max_num_batched_tokens,
            idx_mapping,
            query_start_loc,
            positions,
            self.block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mappings,
            self.slot_mappings.stride(0),
            self.cp_rank,
            CP_SIZE=self.cp_size,
            CP_INTERLEAVE=self.cp_interleave,
            PAD_ID=PAD_SLOT_ID,
            TRITON_BLOCK_SIZE=1024,  # type: ignore
        )
        return self.slot_mappings[:, :num_tokens]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        """Fill the slot-mapping buffer with padding and return its first columns."""
        # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
        # This is because the padding logic is complex and kernels may access beyond
        # the requested range.
        self.slot_mappings.fill_(PAD_SLOT_ID)
        # NOTE(woosuk): The output may be used for CUDA graph capture.
        # Therefore, this method must return the persistent tensor
        # with the same memory address as that used during the model's forward pass,
        # rather than allocating a new tensor.
        return self.slot_mappings[:, :num_tokens]


@triton.jit
def _gather_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_kv_cache_groups]
    dst_block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
    num_blocks_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid: axis 0 is the kv cache group, axis 1 is the position in the batch.
    # Each program copies one request's block IDs from the source table row
    # `req_idx` to the destination table row `batch_idx`.
    group_id = tl.program_id(0)
    batch_idx = tl.program_id(1)
    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)

    # Number of block IDs recorded for this request in this group.
    num_blocks = tl.load(num_blocks_ptr + group_id * num_blocks_stride + req_idx)

    # Source and destination rows use the same per-group row stride.
    stride = tl.load(block_table_strides + group_id)
    src_table = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
    dst_table = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
    src_row_ptr = src_table + req_idx * stride
    dst_row_ptr = dst_table + batch_idx * stride

    # Copy in chunks of BLOCK_SIZE; entries at or beyond `num_blocks` are
    # neither read nor written.
    for chunk_start in tl.range(0, num_blocks, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_range = cols < num_blocks
        block_ids = tl.load(src_row_ptr + cols, mask=in_range)
        tl.store(dst_row_ptr + cols, block_ids, mask=in_range)


@triton.jit
def _compute_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    idx_mapping,  # [num_reqs]
    query_start_loc,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    block_sizes,  # [num_kv_cache_groups]
    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
    slot_mappings_stride,
    cp_rank,
    CP_SIZE: tl.constexpr,
    CP_INTERLEAVE: tl.constexpr,
    PAD_ID: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    # Grid: axis 0 is the kv cache group; axis 1 has one program per request
    # plus one trailing program that only writes padding.
    # kv cache group id
    group_id = tl.program_id(0)
    batch_idx = tl.program_id(1)
    slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride

    if batch_idx == tl.num_programs(1) - 1:
        # Pad the remaining slots [num_tokens, max_num_tokens) with PAD_ID.
        # This is needed for CUDA graphs.
        for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
            offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
        return

    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    block_size = tl.load(block_sizes + group_id)

    # Tokens [start_idx, end_idx) belong to this request.
    req_state_idx = tl.load(idx_mapping + batch_idx)
    start_idx = tl.load(query_start_loc + batch_idx)
    end_idx = tl.load(query_start_loc + batch_idx + 1)
    for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
        offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
        # Out-of-range lanes read position 0, so the unmasked block-table load
        # below still addresses the first entry of this request's row.
        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)

        # One local block spans `block_size * CP_SIZE` positions of the
        # full sequence.
        block_indices = positions // (block_size * CP_SIZE)
        block_offsets = positions % (block_size * CP_SIZE)
        block_numbers = tl.load(
            block_table_ptr + req_state_idx * block_table_stride + block_indices
        )

        if CP_SIZE == 1:
            # Common case: Context parallelism is not used.
            slot_ids = block_numbers * block_size + block_offsets
        else:
            # Context parallelism is used.
            # Positions are dealt to ranks in runs of CP_INTERLEAVE tokens;
            # a position is local when its run index modulo CP_SIZE is cp_rank.
            is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank
            # Full cycles over all ranks that precede this position in the block.
            rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
            remainder = block_offsets % CP_INTERLEAVE
            local_offsets = rounds * CP_INTERLEAVE + remainder
            slot_ids = block_numbers * block_size + local_offsets
            # Positions owned by other ranks get PAD_ID.
            slot_ids = tl.where(is_local, slot_ids, PAD_ID)

        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)


@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    # Read the address stored at `ptr_to_ptr` and reinterpret it as a pointer
    # to `elem_dtype` elements.
    raw_address = tl.load(ptr_to_ptr)
    typed_ptr = tl.cast(raw_address, tl.pointer_type(elem_dtype))
    # Compiler hint: the pointer value is a multiple of 16.
    return tl.multiple_of(typed_ptr, 16)