block_table.py 8.29 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch

7
from vllm.triton_utils import tl, triton
Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.utils.math_utils import cdiv
9
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
10
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


class BlockTables:
    """Block tables for every KV-cache group, plus the buffers the Triton
    kernels in this file read and write.

    For each KV-cache group ``g`` (one per entry of ``block_sizes``):

    * ``block_tables[g]`` is a ``StagedWriteTensor`` of shape
      ``[max_num_reqs, cdiv(max_model_len, block_sizes[g])]`` (int32), indexed
      by request slot.
    * Row ``g`` of ``num_blocks`` holds how many block ids are valid for each
      request slot.
    * ``input_block_tables[g]`` is a same-shaped tensor whose row ``b`` is
      filled by ``gather_block_tables`` from request slot ``idx_mapping[b]``.

    ``slot_mappings`` is an int64 ``[num_groups, max_num_batched_tokens]``
    buffer written by ``compute_slot_mappings``.
    """

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
    ):
        self.block_sizes = block_sizes
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        self.device = device

        self.num_kv_cache_groups = len(block_sizes)

        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]; the column count
        # differs per group because it depends on that group's block size.
        self.block_tables: list[StagedWriteTensor] = [
            StagedWriteTensor(
                (max_num_reqs, cdiv(max_model_len, group_block_size)),
                dtype=torch.int32,
                device=device,
            )
            for group_block_size in block_sizes
        ]
        gpu_tables = [table.gpu for table in self.block_tables]

        # Per-group base addresses and row strides, so that a single kernel
        # launch can index into every group's table.
        self.block_table_ptrs = self._make_ptr_tensor(gpu_tables)
        self.block_table_strides = torch.tensor(
            [table.stride(0) for table in gpu_tables],
            dtype=torch.int64,
            device=self.device,
        )

        self.block_sizes_tensor = torch.tensor(
            block_sizes, dtype=torch.int32, device=self.device
        )
        # Valid block count per (group, request slot). Written via `.np` in
        # append_block_ids and read by the gather kernel via `.gpu`.
        self.num_blocks = UvaBackedTensor(
            (self.num_kv_cache_groups, max_num_reqs),
            dtype=torch.int32,
        )

        # Block tables used for the model's forward pass, with rows ordered by
        # batch position. num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.input_block_tables: list[torch.Tensor] = [
            torch.zeros_like(table) for table in gpu_tables
        ]
        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

        self.slot_mappings = torch.zeros(
            self.num_kv_cache_groups,
            max_num_batched_tokens,
            dtype=torch.int64,
            device=self.device,
        )

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        """Return a 1-D device tensor holding ``t.data_ptr()`` for each ``t`` in ``x``."""
        # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
        addresses = [tensor.data_ptr() for tensor in x]
        return torch.tensor(addresses, dtype=torch.uint64, device=self.device)

    def append_block_ids(
        self,
        req_index: int,
        new_block_ids: tuple[list[int], ...],
        overwrite: bool,
    ) -> None:
        """Stage block ids for request slot ``req_index`` in every group.

        ``new_block_ids[g]`` is written immediately after the blocks already
        recorded for group ``g``, or from column 0 when ``overwrite`` is true.
        The host-side block count is updated right away. The table write is
        only staged; it is applied by ``apply_staged_writes``.
        """
        for group in range(self.num_kv_cache_groups):
            ids = new_block_ids[group]
            offset = 0 if overwrite else self.num_blocks.np[group, req_index]
            self.block_tables[group].stage_write(req_index, offset, ids)
            self.num_blocks.np[group, req_index] = offset + len(ids)

    def apply_staged_writes(self) -> None:
        """Apply every group's staged writes, then copy block counts to the UVA buffer."""
        # TODO(woosuk): This can be inefficient since it launches one kernel per
        # block table. Implement a kernel to handle all block tables at once.
        for table in self.block_tables:
            table.apply_write()
        self.num_blocks.copy_to_uva()

    def gather_block_tables(
        self, idx_mapping: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Copy request rows into batch order for the forward pass.

        For each batch position ``b``, row ``idx_mapping[b]`` of every group's
        block table is copied into row ``b`` of ``input_block_tables``. Only
        the first ``num_blocks`` entries of each row are copied.

        Returns one ``[num_reqs, max_num_blocks]`` view per group.
        """
        num_reqs = idx_mapping.shape[0]
        launch_grid = (self.num_kv_cache_groups, num_reqs)
        _gather_block_tables_kernel[launch_grid](
            idx_mapping,
            self.block_table_ptrs,
            self.input_block_table_ptrs,
            self.block_table_strides,
            self.num_blocks.gpu,
            self.num_blocks.gpu.stride(0),
            BLOCK_SIZE=1024,  # type: ignore
        )
        return tuple(table[:num_reqs] for table in self.input_block_tables)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        """Return the first ``num_reqs`` rows of each input table, without gathering."""
        return tuple(table[:num_reqs] for table in self.input_block_tables)

    def compute_slot_mappings(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Compute per-token KV-cache slots for every group.

        Returns ``slot_mappings[:, :num_tokens]``. The kernel also fills the
        remaining columns up to ``max_num_batched_tokens`` with
        ``PAD_SLOT_ID``.
        """
        num_reqs = idx_mapping.shape[0]
        num_tokens = positions.shape[0]
        # One extra program along axis 1 performs the padding.
        launch_grid = (self.num_kv_cache_groups, num_reqs + 1)
        _compute_slot_mappings_kernel[launch_grid](
            num_tokens,
            self.max_num_batched_tokens,
            idx_mapping,
            query_start_loc,
            positions,
            self.block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mappings,
            self.slot_mappings.stride(0),
            PAD_ID=PAD_SLOT_ID,
            TRITON_BLOCK_SIZE=1024,  # type: ignore
        )
        return self.slot_mappings[:, :num_tokens]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        """Fill the whole slot-mapping buffer with PAD_SLOT_ID and return its first ``num_tokens`` columns."""
        self.slot_mappings.fill_(PAD_SLOT_ID)
        return self.slot_mappings[:, :num_tokens]


@triton.jit
def _gather_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_kv_cache_groups]
    dst_block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
    num_blocks_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one request's block-table row into its batch position.

    Program ``(g, b)`` reads ``r = batch_idx_to_req_idx[b]``. It then copies
    the first ``num_blocks[g, r]`` entries of row ``r`` of source table ``g``
    into row ``b`` of destination table ``g``, ``BLOCK_SIZE`` entries at a
    time. Source and destination use the same row stride,
    ``block_table_strides[g]``.
    """
    group = tl.program_id(0)  # kv cache group id
    row_out = tl.program_id(1)
    row_in = tl.load(batch_idx_to_req_idx + row_out)

    row_stride = tl.load(block_table_strides + group)
    n_valid = tl.load(num_blocks_ptr + group * num_blocks_stride + row_in)

    src_table = _load_ptr(src_block_table_ptrs + group, tl.int32)
    dst_table = _load_ptr(dst_block_table_ptrs + group, tl.int32)
    src_row = src_table + row_in * row_stride
    dst_row = dst_table + row_out * row_stride

    for start in tl.range(0, n_valid, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_valid
        vals = tl.load(src_row + cols, mask=in_bounds)
        tl.store(dst_row + cols, vals, mask=in_bounds)


@triton.jit
def _compute_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    idx_mapping,  # [num_reqs]
    query_start_loc,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    block_sizes,  # [num_kv_cache_groups]
    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
    slot_mappings_stride,
    PAD_ID: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    """Fill the slot mapping of every KV-cache group.

    Launched with grid ``(num_kv_cache_groups, num_reqs + 1)``.

    Program ``(g, b)`` with ``b < num_reqs`` handles the tokens
    ``query_start_loc[b]:query_start_loc[b + 1]`` of request slot
    ``idx_mapping[b]``. For each token position ``p`` it reads
    ``block = block_table[g][req, p // block_size]`` and writes
    ``block * block_size + p % block_size``.

    The last program along axis 1 writes ``PAD_ID`` into entries
    ``[num_tokens, max_num_tokens)`` of group ``g``.
    """
    # kv cache group id
    group_id = tl.program_id(0)
    batch_idx = tl.program_id(1)
    slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride

    if batch_idx == tl.num_programs(1) - 1:
        # Pad remaining slots to PAD_ID. This is needed for CUDA graphs.
        for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
            offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
        return

    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    block_size = tl.load(block_sizes + group_id)

    req_state_idx = tl.load(idx_mapping + batch_idx)
    start_idx = tl.load(query_start_loc + batch_idx)
    end_idx = tl.load(query_start_loc + batch_idx + 1)
    # Loop-invariant start of this request's row in the group's block table.
    row_ptr = block_table_ptr + req_state_idx * block_table_stride
    for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
        offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
        mask = offset < end_idx
        positions = tl.load(pos + offset, mask=mask, other=0)
        block_indices = positions // block_size
        # Unmasked load: masked-out lanes have position 0, so they read column 0
        # of this request's row, which is always in bounds.
        block_numbers = tl.load(row_ptr + block_indices)
        # Block ids and block sizes are both int32. Widen before multiplying so
        # that large block ids do not wrap before reaching the int64 output.
        slot_ids = block_numbers.to(tl.int64) * block_size + positions % block_size
        tl.store(slot_mapping_ptr + offset, slot_ids, mask=mask)


@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    """Load a stored address and reinterpret it as a typed pointer.

    ``ptr_to_ptr`` points at an integer that holds a raw address. In this
    file, the pointer tables built by ``BlockTables._make_ptr_tensor`` store
    ``data_ptr()`` values as uint64. The loaded value is cast to
    ``pointer_type(elem_dtype)``. It is then marked with
    ``tl.multiple_of(..., 16)``, an unchecked hint to the compiler that the
    value is a multiple of 16, so callers must only pass suitably aligned
    tensor bases.
    """
    address = tl.load(ptr_to_ptr)
    typed_ptr = tl.cast(address, tl.pointer_type(elem_dtype))
    return tl.multiple_of(typed_ptr, 16)