block_table.py 8.13 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch

7
from vllm.triton_utils import tl, triton
Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.utils.math_utils import cdiv
9
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
10
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


class BlockTables:
    """Per-KV-cache-group block tables and the slot mappings derived from them.

    Two families of tables are kept, one table per KV cache group:

    * ``block_tables``: persistent tables indexed by request slot
      (``req_index``). They are updated through staged writes
      (``append_block_ids`` followed by ``apply_staged_writes``).
    * ``input_block_tables``: tables indexed by batch position. They are
      refilled by ``gather_block_tables`` and read by
      ``compute_slot_mappings`` and the model's forward pass.
    """

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
    ):
        self.block_sizes = block_sizes
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        self.device = device

        self.num_kv_cache_groups = len(self.block_sizes)
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        # max_num_blocks is per group: enough blocks of that group's block
        # size to cover max_model_len tokens.
        self.block_tables: list[StagedWriteTensor] = []
        for i in range(self.num_kv_cache_groups):
            block_size = self.block_sizes[i]
            max_num_blocks = cdiv(self.max_model_len, block_size)
            block_table = StagedWriteTensor(
                (self.max_num_reqs, max_num_blocks),
                dtype=torch.int32,
                device=device,
            )
            self.block_tables.append(block_table)

        # Device-side arrays of per-group base addresses and row strides.
        # A single kernel launch can then address every group's table.
        self.block_table_ptrs = self._make_ptr_tensor(
            [b.gpu for b in self.block_tables]
        )
        self.block_table_strides = torch.tensor(
            [b.gpu.stride(0) for b in self.block_tables],
            dtype=torch.int64,
            device=self.device,
        )
        self.block_sizes_tensor = torch.tensor(
            self.block_sizes, dtype=torch.int32, device=self.device
        )
        # Number of valid block IDs recorded per (group, request slot).
        self.num_blocks = UvaBackedTensor(
            (self.num_kv_cache_groups, self.max_num_reqs),
            dtype=torch.int32,
        )

        # Block tables used for model's forward pass.
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        # The kernels index these with block_table_strides as well, so they
        # rely on zeros_like reproducing each source table's layout.
        self.input_block_tables: list[torch.Tensor] = [
            torch.zeros_like(b.gpu) for b in self.block_tables
        ]
        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

        # [num_kv_cache_groups, max_num_batched_tokens] output buffer for
        # compute_slot_mappings / get_dummy_slot_mappings.
        self.slot_mappings = torch.zeros(
            self.num_kv_cache_groups,
            self.max_num_batched_tokens,
            dtype=torch.int64,
            device=self.device,
        )

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        """Return a 1-D tensor on ``self.device`` holding each tensor's data_ptr()."""
        # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
        return torch.tensor(
            [t.data_ptr() for t in x],
            dtype=torch.uint64,
            device=self.device,
        )

    def _input_block_table_views(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        # First num_reqs rows of every group's input block table (slices, no copy).
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def append_block_ids(
        self,
        req_index: int,
        new_block_ids: tuple[list[int], ...],
        overwrite: bool,
    ) -> None:
        """Stage new block IDs for one request in every KV cache group.

        Args:
            req_index: Row (request slot) in the persistent block tables.
            new_block_ids: One list of block IDs per KV cache group.
            overwrite: If True, write starting at column 0 and reset the
                recorded block count to ``len(block_ids)``. Otherwise, append
                after the blocks already recorded for this request.

        The writes are only staged through ``StagedWriteTensor.stage_write``.
        ``apply_staged_writes`` flushes them.
        """
        for i in range(self.num_kv_cache_groups):
            start = self.num_blocks.np[i, req_index] if not overwrite else 0
            block_ids = new_block_ids[i]
            self.block_tables[i].stage_write(req_index, start, block_ids)
            self.num_blocks.np[i, req_index] = start + len(block_ids)

    def apply_staged_writes(self) -> None:
        """Apply every staged block-table write and push ``num_blocks`` to UVA."""
        # TODO(woosuk): This can be inefficient since it launches one kernel per
        # block table. Implement a kernel to handle all block tables at once.
        for block_table in self.block_tables:
            block_table.apply_write()
        self.num_blocks.copy_to_uva()

    def gather_block_tables(
        self,
        idx_mapping: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Copy the scheduled requests' rows into the batch-ordered input tables.

        Args:
            idx_mapping: For each batch position, the request slot
                (``req_index``) whose row is copied into that position.

        Returns:
            One view per KV cache group containing the first ``num_reqs`` rows
            of the input block table. Row ``b`` holds the block IDs of request
            ``idx_mapping[b]``. Only the first ``num_blocks`` entries of each
            row are rewritten; entries past that keep their previous values.
        """
        num_reqs = idx_mapping.shape[0]
        # One program per (group, batch position).
        _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
            idx_mapping,
            self.block_table_ptrs,
            self.input_block_table_ptrs,
            self.block_table_strides,
            self.num_blocks.gpu,
            self.num_blocks.gpu.stride(0),
            BLOCK_SIZE=1024,  # type: ignore
        )
        return self._input_block_table_views(num_reqs)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        """Return the first ``num_reqs`` rows of each input table without gathering."""
        return self._input_block_table_views(num_reqs)

    def compute_slot_mappings(
        self,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the KV cache slot of every scheduled token, for every group.

        The computation reads ``input_block_tables``, so ``gather_block_tables``
        is expected to have run for the same batch first.

        Args:
            query_start_loc: Cumulative token offsets, ``num_reqs + 1`` entries.
                Request ``r`` owns tokens
                ``[query_start_loc[r], query_start_loc[r + 1])``.
            positions: Per-token position used to locate the token's block
                (``position // block_size``) and its offset within that block.

        Returns:
            A ``[num_kv_cache_groups, num_tokens]`` view of ``slot_mappings``
            where ``slot = block_id * block_size + position % block_size``.
            Entries from ``num_tokens`` to ``max_num_batched_tokens`` in the
            underlying buffer are set to ``PAD_SLOT_ID``.
        """
        num_reqs = query_start_loc.shape[0] - 1
        num_tokens = positions.shape[0]
        num_groups = self.num_kv_cache_groups
        # The extra program along axis 1 (index num_reqs) writes the padding.
        _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
            num_tokens,
            self.max_num_batched_tokens,
            query_start_loc,
            positions,
            self.input_block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mappings,
            self.slot_mappings.stride(0),
            PAD_ID=PAD_SLOT_ID,
            BLOCK_SIZE=1024,  # type: ignore
        )
        return self.slot_mappings[:, :num_tokens]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        """Fill the whole slot-mapping buffer with PAD_SLOT_ID and return its prefix."""
        self.slot_mappings.fill_(PAD_SLOT_ID)
        return self.slot_mappings[:, :num_tokens]


@triton.jit
def _gather_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_kv_cache_groups]
    dst_block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
    num_blocks_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Axis 0 selects the KV cache group and axis 1 the batch position. Each
    # program copies the valid prefix of one request's row into batch order.
    gid = tl.program_id(0)
    dst_row_idx = tl.program_id(1)
    src_row_idx = tl.load(batch_idx_to_req_idx + dst_row_idx)

    # Number of valid block IDs in the source row.
    row_len = tl.load(num_blocks_ptr + gid * num_blocks_stride + src_row_idx)

    # The same row stride is applied to the source table (indexed by request
    # slot) and to the destination table (indexed by batch position).
    row_stride = tl.load(block_table_strides + gid)
    src = _load_ptr(src_block_table_ptrs + gid, tl.int32) + src_row_idx * row_stride
    dst = _load_ptr(dst_block_table_ptrs + gid, tl.int32) + dst_row_idx * row_stride

    for base in tl.range(0, row_len, BLOCK_SIZE):
        cols = base + tl.arange(0, BLOCK_SIZE)
        in_row = cols < row_len
        vals = tl.load(src + cols, mask=in_row)
        tl.store(dst + cols, vals, mask=in_row)


@triton.jit
def _compute_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    cu_num_tokens,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    page_sizes,  # [num_kv_cache_groups]
    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
    slot_mappings_stride,
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Axis 0 selects the KV cache group and axis 1 the request.
    gid = tl.program_id(0)
    req = tl.program_id(1)
    out_row = slot_mappings_ptr + gid * slot_mappings_stride

    # The last program along axis 1 owns no request. It fills the tail
    # [num_tokens, max_num_tokens) with PAD_ID so every entry of the buffer
    # is defined. This is needed for CUDA graphs.
    if req == tl.num_programs(1) - 1:
        for tail in range(num_tokens, max_num_tokens, BLOCK_SIZE):
            cols = tail + tl.arange(0, BLOCK_SIZE)
            tl.store(out_row + cols, PAD_ID, mask=cols < max_num_tokens)
        return

    page = tl.load(page_sizes + gid)
    row_stride = tl.load(block_table_strides + gid)
    table_row = _load_ptr(block_table_ptrs + gid, tl.int32) + req * row_stride

    first_tok = tl.load(cu_num_tokens + req)
    end_tok = tl.load(cu_num_tokens + req + 1)
    for base in range(first_tok, end_tok, BLOCK_SIZE):
        cols = base + tl.arange(0, BLOCK_SIZE)
        valid = cols < end_tok
        # Masked-out lanes read position 0. This keeps the unmasked block-table
        # load below inside the row.
        p = tl.load(pos + cols, mask=valid, other=0)
        block_ids = tl.load(table_row + p // page)
        tl.store(out_row + cols, block_ids * page + p % page, mask=valid)


@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    # Read an address stored as an integer and reinterpret it as a pointer to
    # elem_dtype. The multiple_of hint declares the value a multiple of 16.
    # NOTE(review): this assumes the stored data_ptr() addresses satisfy that
    # alignment; confirm for every tensor whose address is stored.
    return tl.multiple_of(
        tl.cast(tl.load(ptr_to_ptr), tl.pointer_type(elem_dtype)), 16
    )