# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch

from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor


class BlockTables:
    """Block tables and slot mappings for every KV cache group.

    For each KV cache group this keeps:

    * a persistent ``[max_num_reqs, max_num_blocks]`` int32 block table that
      is updated through staged writes (``append_block_ids`` followed by
      ``apply_staged_writes``),
    * a same-shaped "input" block table that ``gather_block_tables`` fills
      for the current batch, and
    * one row of the ``[num_kv_cache_groups, max_num_batched_tokens]``
      ``slot_mappings`` buffer.
    """

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
    ):
        """Allocate all per-group buffers.

        Args:
            block_sizes: Block size of each KV cache group; its length is the
                number of KV cache groups.
            max_num_reqs: Number of rows of every block table.
            max_num_batched_tokens: Number of columns of ``slot_mappings``.
            max_model_len: Used to size each table's column count as
                ``cdiv(max_model_len, block_size)``.
            device: Device on which the tensors created here are allocated.
        """
        self.block_sizes = block_sizes
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        self.device = device

        self.num_kv_cache_groups = len(self.block_sizes)
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.block_tables: list[StagedWriteTensor] = []
        for block_size in self.block_sizes:
            max_num_blocks = cdiv(self.max_model_len, block_size)
            block_table = StagedWriteTensor(
                (self.max_num_reqs, max_num_blocks),
                dtype=torch.int32,
                device=device,
            )
            self.block_tables.append(block_table)

        # Per-group base addresses and row strides, so that one kernel launch
        # can index into every group's table.
        self.block_table_ptrs = self._make_ptr_tensor(
            [b.gpu for b in self.block_tables]
        )
        self.block_table_strides = torch.tensor(
            [b.gpu.stride(0) for b in self.block_tables],
            dtype=torch.int64,
            device=self.device,
        )

        self.block_sizes_tensor = torch.tensor(
            self.block_sizes, dtype=torch.int32, device=self.device
        )
        # Number of block IDs currently recorded per (group, request) row.
        self.num_blocks = UvaBackedTensor(
            (self.num_kv_cache_groups, self.max_num_reqs),
            dtype=torch.int32,
        )

        # Block tables used for model's forward pass.
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.input_block_tables: list[torch.Tensor] = [
            torch.zeros_like(b.gpu) for b in self.block_tables
        ]
        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

        self.slot_mappings = torch.zeros(
            self.num_kv_cache_groups,
            self.max_num_batched_tokens,
            dtype=torch.int64,
            device=self.device,
        )

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        """Return a uint64 tensor on ``self.device`` holding each tensor's
        ``data_ptr()``, in iteration order."""
        # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
        return torch.tensor(
            [t.data_ptr() for t in x],
            dtype=torch.uint64,
            device=self.device,
        )

    def append_block_ids(
        self,
        req_index: int,
        new_block_ids: tuple[list[int], ...],
        overwrite: bool,
    ) -> None:
        """Stage new block IDs for one request row in every KV cache group.

        Args:
            req_index: Row of the block tables to update.
            new_block_ids: Block IDs to write, indexed by KV cache group.
            overwrite: If true, writing starts at column 0; otherwise it
                starts right after the block IDs already counted for the row.

        The writes are only staged here; ``apply_staged_writes`` applies them.
        The host-side block count of the row is updated immediately.
        """
        for i in range(self.num_kv_cache_groups):
            start = 0 if overwrite else self.num_blocks.np[i, req_index]
            block_ids = new_block_ids[i]
            self.block_tables[i].stage_write(req_index, start, block_ids)
            self.num_blocks.np[i, req_index] = start + len(block_ids)

    def apply_staged_writes(self) -> None:
        """Apply all staged block-table writes and publish ``num_blocks``."""
        # TODO(woosuk): This can be inefficient since it launches one kernel per
        # block table. Implement a kernel to handle all block tables at once.
        for block_table in self.block_tables:
            block_table.apply_write()
        self.num_blocks.copy_to_uva()

    def gather_block_tables(
        self,
        idx_mapping: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Copy the rows selected by ``idx_mapping`` into the input tables.

        Args:
            idx_mapping: One entry per batch row giving the block-table row
                to copy from; its length is the number of requests.

        Returns:
            One tensor per KV cache group: the first ``num_reqs`` rows of the
            corresponding input block table. Only the first ``num_blocks``
            entries of each row are written by the kernel; columns beyond
            that keep whatever they held before.
        """
        num_reqs = idx_mapping.shape[0]
        # One program per (KV cache group, batch row).
        _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
            idx_mapping,
            self.block_table_ptrs,
            self.input_block_table_ptrs,
            self.block_table_strides,
            self.num_blocks.gpu,
            self.num_blocks.gpu.stride(0),
            BLOCK_SIZE=1024,  # type: ignore
        )
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        """Return the first ``num_reqs`` rows of every input block table
        without gathering anything into them."""
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def compute_slot_mappings(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the slot of every token for every KV cache group.

        Args:
            idx_mapping: One entry per batch row giving the block-table row
                of that request.
            query_start_loc: ``num_reqs + 1`` offsets; request ``b`` owns the
                tokens ``[query_start_loc[b], query_start_loc[b + 1])``.
            positions: Position of each token; its length is the number of
                tokens.

        Returns:
            The first ``num_tokens`` columns of ``slot_mappings``, shape
            ``[num_kv_cache_groups, num_tokens]``. Columns from ``num_tokens``
            up to ``max_num_batched_tokens`` of the underlying buffer are set
            to ``PAD_SLOT_ID``.
        """
        num_reqs = idx_mapping.shape[0]
        num_tokens = positions.shape[0]
        num_groups = self.num_kv_cache_groups
        # The extra (last) program along axis 1 only writes the padding.
        _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
            num_tokens,
            self.max_num_batched_tokens,
            idx_mapping,
            query_start_loc,
            positions,
            self.block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mappings,
            self.slot_mappings.stride(0),
            PAD_ID=PAD_SLOT_ID,
            TRITON_BLOCK_SIZE=1024,  # type: ignore
        )
        return self.slot_mappings[:, :num_tokens]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        """Fill the whole slot-mapping buffer with ``PAD_SLOT_ID`` and return
        its first ``num_tokens`` columns."""
        self.slot_mappings.fill_(PAD_SLOT_ID)
        return self.slot_mappings[:, :num_tokens]


@triton.jit
def _gather_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_kv_cache_groups]
    dst_block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
    num_blocks_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per (KV cache group, batch row): copy the valid prefix of
    # the source row selected by `batch_idx_to_req_idx` into the destination
    # row at the batch index.
    kv_group = tl.program_id(0)
    dst_row = tl.program_id(1)
    src_row = tl.load(batch_idx_to_req_idx + dst_row)

    # Number of valid block IDs recorded for (kv_group, src_row).
    n_valid = tl.load(num_blocks_ptr + kv_group * num_blocks_stride + src_row)

    # The same row stride is applied to both the source and destination table.
    row_stride = tl.load(block_table_strides + kv_group)
    src_table = _load_ptr(src_block_table_ptrs + kv_group, tl.int32)
    dst_table = _load_ptr(dst_block_table_ptrs + kv_group, tl.int32)
    src_base = src_table + src_row * row_stride
    dst_base = dst_table + dst_row * row_stride

    # Copy in chunks of BLOCK_SIZE; the mask trims the final partial chunk.
    for chunk_start in tl.range(0, n_valid, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_valid
        values = tl.load(src_base + cols, mask=in_bounds)
        tl.store(dst_base + cols, values, mask=in_bounds)


@triton.jit
def _compute_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    idx_mapping,  # [num_reqs]
    query_start_loc,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    block_sizes,  # [num_kv_cache_groups]
    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
    slot_mappings_stride,
    PAD_ID: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    # Grid: (KV cache group, request index or the extra trailing program).
    # Each request program writes the slot IDs of its own tokens; the last
    # program along axis 1 writes only the padding.
    # kv cache group id
    group_id = tl.program_id(0)
    batch_idx = tl.program_id(1)
    slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride

    if batch_idx == tl.num_programs(1) - 1:
        # Pad the remaining slots [num_tokens, max_num_tokens) with PAD_ID.
        # This is needed for CUDA graphs.
        for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
            offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
        return

    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    block_size = tl.load(block_sizes + group_id)

    # Row of the block table that belongs to this batch entry, and the token
    # range [start_idx, end_idx) that it owns.
    req_state_idx = tl.load(idx_mapping + batch_idx)
    start_idx = tl.load(query_start_loc + batch_idx)
    end_idx = tl.load(query_start_loc + batch_idx + 1)
    for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
        offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
        # Out-of-range lanes read position 0, so the unmasked block-table load
        # below touches column 0 of the row for them; their results are
        # discarded by the masked store.
        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
        block_indices = positions // block_size
        block_numbers = tl.load(
            block_table_ptr + req_state_idx * block_table_stride + block_indices
        )
        # slot = block number * block size + offset of the token in its block
        slot_ids = block_numbers * block_size + positions % block_size
        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)


@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    # Read the address stored at `ptr_to_ptr` and reinterpret it as a pointer
    # to `elem_dtype` elements.
    raw_address = tl.load(ptr_to_ptr)
    typed_ptr = tl.cast(raw_address, tl.pointer_type(elem_dtype))
    # Compiler hint: the pointer value is a multiple of 16.
    return tl.multiple_of(typed_ptr, 16)