# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm.distributed import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv

# Module-level logger, named after this module.
logger = init_logger(__name__)


class BlockTable:
    """Per-request table of KV-cache block IDs plus a per-token slot mapping.

    The table is a ``(max_num_reqs, max_num_blocks_per_req)`` int32 matrix
    that is edited on the CPU through a NumPy view (``block_table_np``) of
    ``block_table_cpu`` and copied to ``block_table`` on ``device`` by
    :meth:`commit_block_table`. ``num_blocks_per_row`` records how many
    leading entries of each row are in use.

    The slot mapping follows the same pattern: it is computed into
    ``slot_mapping_np`` (a NumPy view of ``slot_mapping_cpu``) and copied to
    ``slot_mapping`` by :meth:`commit_slot_mapping`.
    """

    def __init__(
        self,
        block_size: int,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
    ) -> None:
        """Allocate the CPU and device buffers.

        Args:
            block_size: Number of token slots per block.
            max_num_reqs: Number of rows in the block table.
            max_num_blocks_per_req: Number of columns in the block table.
            max_num_batched_tokens: Length of the slot-mapping buffers.
            pin_memory: Forwarded to ``torch.zeros`` for the CPU buffers.
            device: Device holding the committed copies of the buffers.
        """
        self.block_size = block_size
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        table_shape = (max_num_reqs, max_num_blocks_per_req)
        self.block_table = torch.zeros(
            table_shape,
            device=self.device,
            dtype=torch.int32,
        )
        self.block_table_cpu = torch.zeros(
            table_shape,
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        # NumPy view of the CPU tensor: writes through one are visible
        # through the other.
        self.block_table_np = self.block_table_cpu.numpy()
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
                                            dtype=torch.int64,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                        dtype=torch.int64,
                                        device=self.device)

        try:
            dcp_group = get_dcp_group()
        except AssertionError:
            # DCP might not be initialized in testing.
            self.dcp_world_size = 1
            self.dcp_rank = 0
        else:
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        """Append ``block_ids`` after the entries already used in a row.

        An empty ``block_ids`` is a no-op.

        NOTE(review): there is no explicit capacity check. A multi-block
        overflow raises ``ValueError`` from NumPy; a single block appended
        to an already-full row appears to be dropped silently by
        broadcasting -- confirm callers never exceed
        ``max_num_blocks_per_req``.
        """
        if not block_ids:
            return
        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        # Write the IDs before bumping the count so that a failed write
        # does not leave the count ahead of the table contents.
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] = start + num_blocks

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Restart row ``row_idx`` so that it holds exactly ``block_ids``.

        Only the used-entry count is reset; entries beyond the new length
        keep whatever values they had.
        """
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy the used entries and the count of row ``src`` to ``tgt``."""
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange rows ``src`` and ``tgt`` along with their counts."""
        # Fancy indexing copies the right-hand side before assignment, so
        # both exchanges are safe even when src == tgt.
        self.num_blocks_per_row[[src, tgt]] = self.num_blocks_per_row[[
            tgt, src
        ]]
        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        """Fill ``slot_mapping_np[:len(req_indices)]`` for a batch of tokens.

        Args:
            req_indices: Block-table row index of each token.
            positions: Position of each token within its request.

        With ``dcp_world_size > 1`` only the tokens owned by this rank get a
        slot; all other entries are set to -1.
        """
        num_tokens = req_indices.shape[0]
        # Flat indices into the block table, e.g. for positions
        # [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] of requests 0, 1 and 2:
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size`
        # here because M (max_model_len) is not necessarily divisible by
        # block_size.
        if self.dcp_world_size > 1:
            # Note(hc): DCP stores the KV cache interleaved: the entry for
            # the token at index i always lives on the rank equal to
            # i % dcp_world_size.

            # Use a "virtual block" of world_size * block_size tokens to
            # index the block table.
            virtual_block_size = self.block_size * self.dcp_world_size
            block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                   positions // virtual_block_size)
            block_numbers = self.block_table_np.ravel()[block_table_indices]
            # Offsets inside the virtual block select the local tokens.
            virtual_block_offsets = positions % virtual_block_size
            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
            # Offset inside the local (per-rank) block.
            block_offsets = virtual_block_offsets // self.dcp_world_size
            slot_mapping = block_numbers * self.block_size + block_offsets
            # Write the final slots, using -1 for tokens that are not local.
            self.slot_mapping_np[:num_tokens] = np.where(
                mask, slot_mapping, -1)
        else:
            block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                   positions // self.block_size)
            block_numbers = self.block_table_np.ravel()[block_table_indices]
            block_offsets = positions % self.block_size
            # Write straight into the output buffer.
            np.add(block_numbers * self.block_size,
                   block_offsets,
                   out=self.slot_mapping_np[:num_tokens])

    def commit_block_table(self, num_reqs: int) -> None:
        """Copy the first ``num_reqs`` rows of the CPU table to the device.

        The copy is requested with ``non_blocking=True``.
        """
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        """Copy the first ``num_tokens`` slot entries to the device.

        The copy is requested with ``non_blocking=True``.
        """
        self.slot_mapping[:num_tokens].copy_(
            self.slot_mapping_cpu[:num_tokens], non_blocking=True)

    def clear(self) -> None:
        """Zero the device and CPU tables and reset every row count."""
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)
        # Keep the counts in sync with the zeroed table; otherwise a later
        # append_row would start writing at a stale offset.
        self.num_blocks_per_row.fill(0)

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group.

    Holds one ``BlockTable`` per entry of ``block_sizes`` and forwards every
    operation to each of them in order.
    """

    def __init__(self,
                 max_num_reqs: int,
                 max_model_len: int,
                 max_num_batched_tokens: int,
                 pin_memory: bool,
                 device: torch.device,
                 block_sizes: list[int],
                 num_speculative_tokens: int = 0) -> None:
        """Create one ``BlockTable`` per block size.

        Args:
            max_num_reqs: Number of rows in every block table.
            max_model_len: Maximum model length, used to size each row.
            max_num_batched_tokens: Length of the slot-mapping buffers.
            pin_memory: Forwarded to every ``BlockTable``.
            device: Forwarded to every ``BlockTable``.
            block_sizes: Block size of each KV cache group.
            num_speculative_tokens: Each row gets at least
                ``1 + num_speculative_tokens`` columns.
        """
        # Note(hc): each dcp rank only stores
        # (max_model_len // dcp_world_size) tokens in its KV cache, so the
        # block size used to compute max_num_blocks_per_req must be
        # multiplied by dcp_world_size.
        try:
            dcp_world_size = get_dcp_group().world_size
        except AssertionError:
            # DCP might not be initialized in testing.
            dcp_world_size = 1

        min_blocks_per_req = 1 + num_speculative_tokens
        self.block_tables = [
            BlockTable(
                block_size,
                max_num_reqs,
                max(cdiv(max_model_len, block_size * dcp_world_size),
                    min_blocks_per_req),
                max_num_batched_tokens,
                pin_memory,
                device,
            ) for block_size in block_sizes
        ]

    def append_row(self, block_ids: tuple[list[int], ...],
                   row_idx: int) -> None:
        """Append ``block_ids[i]`` to row ``row_idx`` of the i-th table."""
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        """Call ``add_row(block_ids[i], row_idx)`` on the i-th table."""
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Call ``move_row(src, tgt)`` on every table."""
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Call ``swap_row(src, tgt)`` on every table."""
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        """Compute the slot mapping of every table for the same tokens."""
        for block_table in self.block_tables:
            block_table.compute_slot_mapping(req_indices, positions)

    def commit_block_table(self, num_reqs: int) -> None:
        """Call ``commit_block_table(num_reqs)`` on every table."""
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        """Call ``commit_slot_mapping(num_tokens)`` on every table."""
        for block_table in self.block_tables:
            block_table.commit_slot_mapping(num_tokens)

    def clear(self) -> None:
        """Call ``clear()`` on every table."""
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]