# block_table.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import numpy as np
import torch

7
from vllm.distributed import get_dcp_group
8
from vllm.logger import init_logger
9
from vllm.utils import cdiv
10
11
12
13
14
15
16
17

# Module-level logger, created with vLLM's init_logger (imported from
# vllm.logger). NOTE(review): not referenced by the code visible here.
logger = init_logger(__name__)


class BlockTable:
    """Per-request table of KV-cache block IDs, plus slot-mapping buffers.

    The block table is held in three views:

    * ``block_table``: an int32 tensor on ``device``;
    * ``block_table_cpu``: an int32 CPU tensor (pinned if ``pin_memory``);
    * ``block_table_np``: a NumPy array sharing memory with
      ``block_table_cpu``.

    Row edits (``add_row``/``append_row``/``move_row``/``swap_row``) are made
    through the NumPy view. ``commit_block_table`` then copies the first rows
    to the device tensor. ``num_blocks_per_row`` records how many leading
    entries of each row are valid.
    """

    def __init__(
        self,
        block_size: int,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
    ):
        """Allocate zeroed block-table and slot-mapping buffers.

        Args:
            block_size: Number of token slots per KV-cache block.
            max_num_reqs: Number of rows in the block table.
            max_num_blocks_per_req: Number of columns in the block table.
            max_num_batched_tokens: Length of the slot-mapping buffers.
            pin_memory: Whether the CPU staging tensors are pinned.
            device: Device holding ``block_table`` and ``slot_mapping``.
        """
        self.block_size = block_size
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        self.block_table = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device=self.device,
            dtype=torch.int32,
        )
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        # Tensor.numpy() shares storage with the CPU tensor, so writes through
        # this view are what commit_block_table() later copies to the device.
        self.block_table_np = self.block_table_cpu.numpy()
        # Count of valid leading entries in each row of the block table.
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
                                            dtype=torch.int64,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                        dtype=torch.int64,
                                        device=self.device)

        try:
            # Look the group up once instead of twice.
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        """Append ``block_ids`` after the existing valid blocks of a row.

        The row's block count is advanced only after the write succeeds. If
        the blocks do not fit in the row, NumPy raises ``ValueError`` and
        ``num_blocks_per_row`` is left unchanged.
        """
        if not block_ids:
            return
        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] += num_blocks

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Replace the contents of row ``row_idx`` with ``block_ids``."""
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy the valid blocks of row ``src`` into row ``tgt``.

        Entries of ``tgt`` beyond the copied count are not cleared; they are
        excluded by ``num_blocks_per_row``.
        """
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange rows ``src`` and ``tgt`` along with their block counts."""
        num_blocks_src = self.num_blocks_per_row[src]
        num_blocks_tgt = self.num_blocks_per_row[tgt]
        self.num_blocks_per_row[src] = num_blocks_tgt
        self.num_blocks_per_row[tgt] = num_blocks_src

        # Fancy indexing on the right-hand side yields a copy, so this swap
        # does not read rows that it has already overwritten.
        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        """Fill ``slot_mapping_np[:len(req_indices)]`` with KV-cache slots.

        Token ``i`` belongs to block-table row ``req_indices[i]`` and sits at
        position ``positions[i]`` within that request. Its slot is
        ``block_number * block_size + offset_in_block``. With DCP enabled,
        tokens not stored on this rank get ``-1``.
        """
        # Flat block-table indices, e.g. for positions
        # [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] of requests 0, 1 and 2:
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size`
        # here because M (max_model_len) is not necessarily divisible by
        # block_size.
        num_tokens = req_indices.shape[0]
        if self.dcp_world_size > 1:
            # Note(hc): DCP stores the KV cache in an interleaved style: the
            # KV cache of the token at position i is stored on the rank whose
            # dcp_rank equals i % dcp_world_size.

            # Use a "virtual block" of size dcp_world_size * block_size to
            # compute block_table_indices.
            virtual_block_size = self.block_size * self.dcp_world_size
            block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                   positions // virtual_block_size)
            block_numbers = self.block_table_np.ravel()[block_table_indices]
            # Offsets inside the virtual block determine which tokens are
            # local to this rank.
            virtual_block_offsets = positions % virtual_block_size
            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
            # Calculate the offsets within the local (physical) block.
            block_offsets = virtual_block_offsets // self.dcp_world_size
            # Calculate the slot mapping.
            slot_mapping = block_numbers * self.block_size + block_offsets
            # Write the final slots, using -1 for non-local tokens.
            self.slot_mapping_np[:num_tokens] = np.where(
                mask, slot_mapping, -1)
        else:
            block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                   positions // self.block_size)
            block_numbers = self.block_table_np.ravel()[block_table_indices]
            block_offsets = positions % self.block_size
            np.add(block_numbers * self.block_size,
                   block_offsets,
                   out=self.slot_mapping_np[:num_tokens])

    def commit_block_table(self, num_reqs: int) -> None:
        """Copy the first ``num_reqs`` CPU rows to the device table."""
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        """Copy the first ``num_tokens`` CPU slots to the device buffer."""
        self.slot_mapping[:num_tokens].copy_(
            self.slot_mapping_cpu[:num_tokens], non_blocking=True)

    def clear(self) -> None:
        """Zero the device and CPU block tables and reset every row count."""
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)
        # Keep the bookkeeping consistent with the now-empty table; otherwise
        # a later append_row() would resume at a stale offset.
        self.num_blocks_per_row.fill(0)

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group.

    Every row operation is forwarded to each group's ``BlockTable``. The
    tuple-taking methods pass ``block_ids[i]`` to the ``i``-th group.
    """

    def __init__(self,
                 max_num_reqs: int,
                 max_model_len: int,
                 max_num_batched_tokens: int,
                 pin_memory: bool,
                 device: torch.device,
                 block_sizes: list[int],
                 num_speculative_tokens: int = 0) -> None:
        """Create one ``BlockTable`` per entry of ``block_sizes``.

        Each table gets ``cdiv(max_model_len, block_size * dcp_world_size)``
        columns per request, but never fewer than
        ``1 + num_speculative_tokens``.
        """
        # Note(hc): each DCP rank only stores about
        # (max_model_len // dcp_world_size) tokens in its KV cache, so the
        # block_size used to calculate max_num_blocks_per_req must be
        # multiplied by dcp_world_size.
        try:
            dcp_world_size = get_dcp_group().world_size
        except AssertionError:
            # DCP might not be initialized in testing
            dcp_world_size = 1

        self.block_tables = [
            BlockTable(
                block_size, max_num_reqs,
                max(cdiv(max_model_len, block_size * dcp_world_size),
                    1 + num_speculative_tokens), max_num_batched_tokens,
                pin_memory, device) for block_size in block_sizes
        ]

    def append_row(self, block_ids: tuple[list[int], ...],
                   row_idx: int) -> None:
        """Append ``block_ids[i]`` to row ``row_idx`` of the i-th group."""
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        """Replace row ``row_idx`` of the i-th group with ``block_ids[i]``."""
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy row ``src`` to row ``tgt`` in every group."""
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Swap rows ``src`` and ``tgt`` in every group."""
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        """Compute each group's slot mapping for the same token batch."""
        for block_table in self.block_tables:
            block_table.compute_slot_mapping(req_indices, positions)

    def commit_block_table(self, num_reqs: int) -> None:
        """Copy the first ``num_reqs`` rows to the device in every group."""
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        """Copy the first ``num_tokens`` slots to the device in every group."""
        for block_table in self.block_tables:
            block_table.commit_slot_mapping(num_tokens)

    def clear(self) -> None:
        """Clear every group's block table."""
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]