# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch

from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor


class BlockTables:
    """Block tables and slot mappings for every KV cache group.

    One block table is kept per entry of ``block_sizes``. Each table has one
    row per request slot (``max_num_reqs`` rows). Updates are staged through
    ``StagedWriteTensor`` and take effect when ``apply_staged_writes`` is
    called. A second set of tables (``input_block_tables``) holds the rows
    gathered in batch order for the model's forward pass.
    """

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
        cp_size: int = 1,
        cp_rank: int = 0,
        cp_interleave: int = 1,
    ):
        """Allocate the block tables, their pointer tables and slot mappings.

        Args:
            block_sizes: Block size of each KV cache group; its length defines
                the number of KV cache groups.
            max_num_reqs: Number of rows in every block table.
            max_num_batched_tokens: Number of columns in the slot-mapping
                buffer.
            max_model_len: Maximum sequence length; determines the number of
                columns in each block table.
            device: Device on which the tensors are created.
            cp_size: Context-parallel world size (1 disables it).
            cp_rank: Rank of this worker within the context-parallel group.
            cp_interleave: Interleave factor passed to the slot-mapping
                kernel as ``CP_INTERLEAVE``.
        """
        self.block_sizes = block_sizes
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        self.device = device

        self.cp_size = cp_size
        self.cp_rank = cp_rank
        self.cp_interleave = cp_interleave

        self.num_kv_cache_groups = len(self.block_sizes)
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.block_tables: list[StagedWriteTensor] = []
        for block_size in self.block_sizes:
            # When using DCP, each request's KV cache is sharded among different ranks.
            # As a result, one block on the current rank covers `block_size * cp_size`
            # tokens in the full, global (unsharded) sequence.
            max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
            block_table = StagedWriteTensor(
                (self.max_num_reqs, max_num_blocks),
                dtype=torch.int32,
                device=device,
            )
            self.block_tables.append(block_table)
        # Per-group base addresses and row strides, so that a single kernel
        # launch can address the table of any group.
        self.block_table_ptrs = self._make_ptr_tensor(
            [b.gpu for b in self.block_tables]
        )
        self.block_table_strides = torch.tensor(
            [b.gpu.stride(0) for b in self.block_tables],
            dtype=torch.int64,
            device=self.device,
        )

        self.block_sizes_tensor = torch.tensor(
            self.block_sizes, dtype=torch.int32, device=self.device
        )
        # Number of block IDs currently recorded for each (group, request).
        self.num_blocks = UvaBackedTensor(
            (self.num_kv_cache_groups, self.max_num_reqs),
            dtype=torch.int32,
        )

        # Block tables used for model's forward pass.
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.input_block_tables: list[torch.Tensor] = [
            torch.zeros_like(b.gpu) for b in self.block_tables
        ]
        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

        # [num_kv_cache_groups, max_num_batched_tokens]
        self.slot_mappings = torch.zeros(
            self.num_kv_cache_groups,
            self.max_num_batched_tokens,
            dtype=torch.int64,
            device=self.device,
        )

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        """Return a tensor on ``self.device`` holding ``data_ptr()`` of each tensor."""
        # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
        return torch.tensor(
            [t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
        )

    def append_block_ids(
        self,
        req_index: int,
        new_block_ids: tuple[list[int], ...],
        overwrite: bool,
    ) -> None:
        """Stage new block IDs for one request in every KV cache group.

        Args:
            req_index: Row of the request in the block tables.
            new_block_ids: One list of block IDs per KV cache group.
            overwrite: If true, the IDs are written from column 0 and replace
                the recorded count; otherwise they are written after the
                blocks already recorded for the request.
        """
        for i in range(self.num_kv_cache_groups):
            start = self.num_blocks.np[i, req_index] if not overwrite else 0
            block_ids = new_block_ids[i]
            self.block_tables[i].stage_write(req_index, start, block_ids)
            self.num_blocks.np[i, req_index] = start + len(block_ids)

    def apply_staged_writes(self) -> None:
        """Apply all staged block-table writes and publish the block counts."""
        # TODO(woosuk): This can be inefficient since it launches one kernel per
        # block table. Implement a kernel to handle all block tables at once.
        for block_table in self.block_tables:
            block_table.apply_write()
        self.num_blocks.copy_to_uva()

    def gather_block_tables(
        self, idx_mapping: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Gather block-table rows into batch order.

        Args:
            idx_mapping: For each batch position, the request row to read
                from the block tables.

        Returns:
            One tensor per KV cache group: the first ``len(idx_mapping)`` rows
            of the corresponding input block table.
        """
        num_reqs = idx_mapping.shape[0]
        # One program per (KV cache group, batch position).
        _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
            idx_mapping,
            self.block_table_ptrs,
            self.input_block_table_ptrs,
            self.block_table_strides,
            self.num_blocks.gpu,
            self.num_blocks.gpu.stride(0),
            BLOCK_SIZE=1024,  # type: ignore
        )
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        """Return the first ``num_reqs`` rows of each input block table as-is.

        No gathering is performed; the rows hold whatever they currently
        contain.
        """
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def compute_slot_mappings(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the slot mapping of every token for every KV cache group.

        Args:
            idx_mapping: For each batch position, the request row in the
                block tables.
            query_start_loc: Token offsets delimiting each request in the
                batch (``num_reqs + 1`` entries).
            positions: Position of each token within its sequence.

        Returns:
            The first ``num_tokens`` columns of the slot-mapping buffer, one
            row per KV cache group.
        """
        num_reqs = idx_mapping.shape[0]
        num_tokens = positions.shape[0]
        num_groups = self.num_kv_cache_groups
        # The extra program along axis 1 pads the unused tail of the buffer
        # with PAD_SLOT_ID.
        _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
            num_tokens,
            self.max_num_batched_tokens,
            idx_mapping,
            query_start_loc,
            positions,
            self.block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mappings,
            self.slot_mappings.stride(0),
            self.cp_rank,
            CP_SIZE=self.cp_size,
            CP_INTERLEAVE=self.cp_interleave,
            PAD_ID=PAD_SLOT_ID,
            TRITON_BLOCK_SIZE=1024,  # type: ignore
        )
        return self.slot_mappings[:, :num_tokens]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        """Fill the whole slot-mapping buffer with PAD_SLOT_ID and return a view.

        Returns:
            The first ``num_tokens`` columns of the buffer.
        """
        self.slot_mappings.fill_(PAD_SLOT_ID)
        return self.slot_mappings[:, :num_tokens]


@triton.jit
def _gather_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_kv_cache_groups]
    dst_block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
    num_blocks_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Copies the valid block IDs of one request from its row in the source
    # block table into the row of its batch position in the destination table.
    # Program axis 0 selects the KV cache group, axis 1 the batch position.
    kv_group = tl.program_id(0)
    dst_row = tl.program_id(1)
    # Row of this request in the source table.
    src_row = tl.load(batch_idx_to_req_idx + dst_row)

    # The same per-group row stride is used for source and destination.
    row_stride = tl.load(block_table_strides + kv_group)
    src_base = _load_ptr(src_block_table_ptrs + kv_group, tl.int32)
    dst_base = _load_ptr(dst_block_table_ptrs + kv_group, tl.int32)
    src_row_ptr = src_base + src_row * row_stride
    dst_row_ptr = dst_base + dst_row * row_stride

    # Number of block IDs recorded for this (group, request).
    n_valid = tl.load(num_blocks_ptr + kv_group * num_blocks_stride + src_row)

    # Copy in chunks of BLOCK_SIZE; entries at or beyond n_valid in the
    # destination row are left untouched.
    for chunk_start in tl.range(0, n_valid, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_range = cols < n_valid
        ids = tl.load(src_row_ptr + cols, mask=in_range)
        tl.store(dst_row_ptr + cols, ids, mask=in_range)


@triton.jit
def _compute_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    idx_mapping,  # [num_reqs]
    query_start_loc,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    block_sizes,  # [num_kv_cache_groups]
    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
    slot_mappings_stride,
    cp_rank,
    CP_SIZE: tl.constexpr,
    CP_INTERLEAVE: tl.constexpr,
    PAD_ID: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    # Writes, for one KV cache group and one request, the KV cache slot ID of
    # every token of that request. Program axis 0 selects the KV cache group;
    # axis 1 selects the request, with one extra trailing program that only
    # pads the unused tail of the slot-mapping row.

    # kv cache group id
    group_id = tl.program_id(0)
    batch_idx = tl.program_id(1)
    slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride

    if batch_idx == tl.num_programs(1) - 1:
        # Pad remaining slots [num_tokens, max_num_tokens) with PAD_ID.
        # This is needed for CUDA graphs.
        for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
            offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
        return

    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    block_size = tl.load(block_sizes + group_id)

    # Row of this request in the block table, and its token range in the batch.
    req_state_idx = tl.load(idx_mapping + batch_idx)
    start_idx = tl.load(query_start_loc + batch_idx)
    end_idx = tl.load(query_start_loc + batch_idx + 1)
    for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
        offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
        # Masked-off lanes read position 0, so the unmasked block-table load
        # below reads column 0 of the request's row for them.
        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)

        # One block-table entry covers `block_size * CP_SIZE` global positions.
        block_indices = positions // (block_size * CP_SIZE)
        block_offsets = positions % (block_size * CP_SIZE)
        block_numbers = tl.load(
            block_table_ptr + req_state_idx * block_table_stride + block_indices
        )

        if CP_SIZE == 1:
            # Common case: Context parallelism is not used.
            slot_ids = block_numbers * block_size + block_offsets
        else:
            # Context parallelism is used. Offsets within a block are split
            # into chunks of CP_INTERLEAVE tokens assigned to ranks in
            # round-robin order; only chunks owned by `cp_rank` get a slot.
            is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank
            # Number of complete round-robin rounds before this offset.
            rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
            # Index within the current chunk.
            remainder = block_offsets % CP_INTERLEAVE
            local_offsets = rounds * CP_INTERLEAVE + remainder
            slot_ids = block_numbers * block_size + local_offsets
            # Tokens stored on other ranks are marked with PAD_ID.
            slot_ids = tl.where(is_local, slot_ids, PAD_ID)

        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)


@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    # Reads an address stored at `ptr_to_ptr` and reinterprets it as a pointer
    # to `elem_dtype` elements.
    raw_addr = tl.load(ptr_to_ptr)
    typed_ptr = tl.cast(raw_addr, tl.pointer_type(elem_dtype))
    # Hint to the compiler that the address is a multiple of 16.
    return tl.multiple_of(typed_ptr, 16)