block_table.py 10.2 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch

7
from vllm.distributed import get_dcp_group
8
from vllm.triton_utils import tl, triton
Woosuk Kwon's avatar
Woosuk Kwon committed
9
from vllm.utils.math_utils import cdiv
10
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
11
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
14
15
16
17
18
19
20
21


class BlockTables:
    """Per-request block tables and slot mappings for all KV cache groups.

    For each KV cache group this keeps:
      * a persistent ``[max_num_reqs, max_num_blocks]`` int32 block table
        (``StagedWriteTensor``) indexed by request-state row, updated via
        staged writes;
      * a per-request block count (``num_blocks``);
      * an input block table of the same shape whose rows are gathered in
        batch order for the model's forward pass.
    It also owns an int64 ``[num_kv_cache_groups, max_num_batched_tokens]``
    slot-mapping buffer that ``compute_slot_mappings`` fills per step.
    """

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
        cp_kv_cache_interleave_size: int = 1,
    ):
        self.block_sizes = block_sizes
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        self.device = device

        # Number of consecutive tokens (within a virtual block) that are placed
        # on one CP rank before moving on to the next rank.
        assert cp_kv_cache_interleave_size >= 1
        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size

        try:
            dcp = get_dcp_group()
            self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group
        except AssertionError:
            # No DCP group available (presumably DCP is not initialized):
            # behave as a single-rank setup.
            self.dcp_world_size, self.dcp_rank = 1, 0
        # TODO(wentao): PCP support
        self.total_cp_world_size = self.dcp_world_size
        self.total_cp_rank = self.dcp_rank

        self.num_kv_cache_groups = len(self.block_sizes)
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.block_tables: list[StagedWriteTensor] = []
        for block_size in self.block_sizes:
            # With DCP, a request's KV is sharded across ranks, so one physical
            # block on this rank corresponds to `block_size * total_cp_world_size`
            # tokens in the global (unsharded) sequence.
            virtual_block_size = block_size * self.total_cp_world_size
            max_num_blocks = cdiv(self.max_model_len, virtual_block_size)
            block_table = StagedWriteTensor(
                (self.max_num_reqs, max_num_blocks),
                dtype=torch.int32,
                device=device,
            )
            self.block_tables.append(block_table)

        # Device-side arrays of per-group table addresses and row strides, so a
        # single kernel launch can address every group's table.
        self.block_table_ptrs = self._make_ptr_tensor(
            [b.gpu for b in self.block_tables]
        )
        self.block_table_strides = torch.tensor(
            [b.gpu.stride(0) for b in self.block_tables],
            dtype=torch.int64,
            device=self.device,
        )
        self.block_sizes_tensor = torch.tensor(
            self.block_sizes, dtype=torch.int32, device=self.device
        )
        # Number of valid block IDs per (group, request-state row).
        self.num_blocks = UvaBackedTensor(
            (self.num_kv_cache_groups, self.max_num_reqs),
            dtype=torch.int32,
        )

        # Block tables used for model's forward pass.
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.input_block_tables: list[torch.Tensor] = [
            torch.zeros_like(b.gpu) for b in self.block_tables
        ]
        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

        # Reused every step; compute_slot_mappings returns views into it.
        self.slot_mappings = torch.zeros(
            self.num_kv_cache_groups,
            self.max_num_batched_tokens,
            dtype=torch.int64,
            device=self.device,
        )

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        """Return a 1-D uint64 tensor on ``self.device`` holding each tensor's data_ptr()."""
        # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
        return torch.tensor(
            [t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
        )

    def append_block_ids(
        self,
        req_index: int,
        new_block_ids: tuple[list[int], ...],
        overwrite: bool,
    ) -> None:
        """Stage new block IDs for one request in every KV cache group.

        Args:
            req_index: Request-state row in the block tables.
            new_block_ids: One list of block IDs per KV cache group.
            overwrite: If True, write from column 0 (replacing the row's
                prefix); otherwise append after the current block count.

        The writes take effect on the device only after apply_staged_writes().
        """
        for group, block_ids in enumerate(new_block_ids[: self.num_kv_cache_groups]):
            start = 0 if overwrite else self.num_blocks.np[group, req_index]
            self.block_tables[group].stage_write(req_index, start, block_ids)
            self.num_blocks.np[group, req_index] = start + len(block_ids)

    def apply_staged_writes(self) -> None:
        """Flush staged block-table writes and sync block counts to the device."""
        # TODO(woosuk): This can be inefficient since it launches one kernel per
        # block table. Implement a kernel to handle all block tables at once.
        for block_table in self.block_tables:
            block_table.apply_write()
        self.num_blocks.copy_to_uva()

    def gather_block_tables(
        self, idx_mapping: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Copy block-table rows into batch order for the forward pass.

        Row ``idx_mapping[b]`` of each persistent table is copied into row ``b``
        of the matching input table; only the first ``num_blocks`` entries of
        each row are copied.

        Returns:
            One ``[num_reqs, max_num_blocks]`` view per KV cache group.
        """
        num_reqs = idx_mapping.shape[0]
        _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
            idx_mapping,
            self.block_table_ptrs,
            self.input_block_table_ptrs,
            self.block_table_strides,
            self.num_blocks.gpu,
            self.num_blocks.gpu.stride(0),
            BLOCK_SIZE=1024,  # type: ignore
        )
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        """Return the first ``num_reqs`` rows of each input table without copying."""
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def compute_slot_mappings(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the KV cache slot of every scheduled token for each group.

        Args:
            idx_mapping: ``[num_reqs]`` batch index -> request-state row.
            query_start_loc: ``[num_reqs + 1]`` token-range boundaries per request.
            positions: ``[num_tokens]`` sequence position of each token.

        Returns:
            A ``[num_kv_cache_groups, num_tokens]`` view of the slot-mapping
            buffer. Tokens not stored on this CP rank get PAD_SLOT_ID, and the
            buffer tail past ``num_tokens`` is also filled with PAD_SLOT_ID.
        """
        num_reqs = idx_mapping.shape[0]
        num_tokens = positions.shape[0]
        num_groups = self.num_kv_cache_groups
        # One extra program along axis 1 pads the unused tail of the buffer.
        _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
            num_tokens,
            self.max_num_batched_tokens,
            idx_mapping,
            query_start_loc,
            positions,
            self.block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mappings,
            self.slot_mappings.stride(0),
            TOTAL_CP_WORLD_SIZE=self.total_cp_world_size,
            TOTAL_CP_RANK=self.total_cp_rank,
            CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
            PAD_ID=PAD_SLOT_ID,
            TRITON_BLOCK_SIZE=1024,  # type: ignore
        )
        return self.slot_mappings[:, :num_tokens]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        """Fill the whole slot-mapping buffer with PAD_SLOT_ID and return a prefix view."""
        self.slot_mappings.fill_(PAD_SLOT_ID)
        return self.slot_mappings[:, :num_tokens]


@triton.jit
def _gather_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_kv_cache_groups]
    dst_block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
    num_blocks_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one request's block-table row for one KV cache group.

    Program ``(g, b)`` reads ``req = batch_idx_to_req_idx[b]`` and copies the
    first ``num_blocks[g, req]`` entries of row ``req`` of source table ``g``
    into row ``b`` of destination table ``g``, ``BLOCK_SIZE`` entries at a time.
    """
    group = tl.program_id(0)
    batch_pos = tl.program_id(1)
    req = tl.load(batch_idx_to_req_idx + batch_pos)

    # Number of valid block IDs stored for this request in this group.
    count = tl.load(num_blocks_ptr + group * num_blocks_stride + req)

    # The same per-group row stride is used for source and destination rows.
    row_stride = tl.load(block_table_strides + group)
    src_table = _load_ptr(src_block_table_ptrs + group, tl.int32)
    dst_table = _load_ptr(dst_block_table_ptrs + group, tl.int32)
    src_row = src_table + req * row_stride
    dst_row = dst_table + batch_pos * row_stride

    for chunk_start in tl.range(0, count, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < count
        ids = tl.load(src_row + cols, mask=in_bounds)
        tl.store(dst_row + cols, ids, mask=in_bounds)


@triton.jit
def _compute_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    idx_mapping,  # [num_reqs]
    query_start_loc,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    block_sizes,  # [num_kv_cache_groups]
    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
    slot_mappings_stride,
    TOTAL_CP_WORLD_SIZE: tl.constexpr,
    TOTAL_CP_RANK: tl.constexpr,
    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
    PAD_ID: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    """Write the KV cache slot ID of each token for one (group, request).

    Program ``(g, b)`` handles the tokens ``query_start_loc[b]`` to
    ``query_start_loc[b + 1]`` of request row ``idx_mapping[b]`` in KV cache group ``g``.
    For position ``p`` the slot is
    ``block_table[req, p // (block_size * W)] * block_size + local_offset``.
    Here ``W = TOTAL_CP_WORLD_SIZE``, and ``local_offset`` is the offset of ``p``
    within this rank's shard of the virtual block. Tokens owned by another CP rank get
    ``PAD_ID``.

    The last program along axis 1 does no request work; it fills slots
    ``[num_tokens, max_num_tokens)`` with ``PAD_ID``.
    """
    # kv cache group id
    group_id = tl.program_id(0)
    batch_idx = tl.program_id(1)
    slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride

    if batch_idx == tl.num_programs(1) - 1:
        # Pad the remaining slots with PAD_ID. This is needed for CUDA graphs.
        for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
            offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
        return

    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    block_size = tl.load(block_sizes + group_id)
    # Tokens covered by one physical block on this rank, counted in the
    # global (unsharded) sequence.
    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE

    req_state_idx = tl.load(idx_mapping + batch_idx)
    start_idx = tl.load(query_start_loc + batch_idx)
    end_idx = tl.load(query_start_loc + batch_idx + 1)
    for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
        offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
        block_indices = positions // virtual_block_size
        # Unmasked load: lanes past end_idx got position 0 and therefore read
        # column 0 of the row; their results are dropped by the masked store.
        block_numbers = tl.load(
            block_table_ptr + req_state_idx * block_table_stride + block_indices
        )
        virtual_block_offsets = positions - block_indices * virtual_block_size

        # Determine whether the token is stored on this CP rank: runs of
        # CP_KV_CACHE_INTERLEAVE_SIZE tokens are assigned to ranks round-robin.
        is_local = (
            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
        # Map virtual block offsets to local block offsets.
        local_block_offsets = (
            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
        )

        # Physical slot index.
        slot_ids = block_numbers * block_size + local_block_offsets
        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)


@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    """Load an address stored at ``ptr_to_ptr`` and return it as a typed pointer.

    The loaded integer is cast to a pointer to ``elem_dtype``. The compiler is
    told the pointer value is a multiple of 16.
    """
    raw_address = tl.load(ptr_to_ptr)
    typed_ptr = tl.cast(raw_address, tl.pointer_type(elem_dtype))
    # Alignment hint for the compiler.
    return tl.multiple_of(typed_ptr, 16)