block_table.py 8.35 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Union
4

5
6
7
import numpy as np
import torch

8
from vllm.distributed import get_dcp_group
9
from vllm.logger import init_logger
10
from vllm.utils import cdiv
11
from vllm.v1.utils import CpuGpuBuffer
12
13
14
15
16
17
18
19

# Module-level logger (currently not referenced by the classes in this module).
logger = init_logger(__name__)


class BlockTable:

    def __init__(
        self,
20
        block_size: int,
21
22
        max_num_reqs: int,
        max_num_blocks_per_req: int,
23
        max_num_batched_tokens: int,
24
25
26
        pin_memory: bool,
        device: torch.device,
    ):
27
        self.block_size = block_size
28
29
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
30
        self.max_num_batched_tokens = max_num_batched_tokens
31
32
33
        self.pin_memory = pin_memory
        self.device = device

34
35
36
        self.block_table = self._make_buffer(max_num_reqs,
                                             max_num_blocks_per_req,
                                             dtype=torch.int32)
37
38
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

39
40
        self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
                                              dtype=torch.int64)
41
42
43
44
45
46
47
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
48

49
50
    def append_row(
        self,
51
        block_ids: list[int],
52
        row_idx: int,
53
    ) -> None:
54
55
        if not block_ids:
            return
56
        num_blocks = len(block_ids)
57
58
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
59
        self.block_table.np[row_idx, start:start + num_blocks] = block_ids
60

61
    def add_row(self, block_ids: list[int], row_idx: int) -> None:
62
63
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)
64
65
66

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
67
68
        block_table_np = self.block_table.np
        block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
69
70
        self.num_blocks_per_row[tgt] = num_blocks

71
    def swap_row(self, src: int, tgt: int) -> None:
72
73
74
        src_tgt, tgt_src = [src, tgt], [tgt, src]
        self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
        self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
75

76
77
78
79
80
81
82
83
    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size`
        # here because M (max_model_len) is not necessarily divisible by
        # block_size.
84
        if self.dcp_world_size > 1:
85
            # Note(hc): The DCP implement store kvcache with an interleave
86
87
88
89
90
91
92
93
            # style, the kvcache for the token whose token_idx is i is
            # always stored on the GPU whose dcp_rank equals i % cp_world_size:

            # Use a "virtual block" which equals to world_size * block_size
            # for block_table_indices calculation.
            virtual_block_size = self.block_size * self.dcp_world_size
            block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                   positions // virtual_block_size)
94
            block_numbers = self.block_table.np.ravel()[block_table_indices]
95
96
97
98
            # Use virtual_block_size for mask calculation, which marks local
            # tokens.
            virtual_block_offsets = positions % virtual_block_size
            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
co63oc's avatar
co63oc committed
99
            # Calculate local block_offsets
100
            block_offsets = virtual_block_offsets // self.dcp_world_size
co63oc's avatar
co63oc committed
101
            # Calculate slot_mapping
102
103
            slot_mapping = block_numbers * self.block_size + block_offsets
            # Write final slots, use -1 for not-local
104
            self.slot_mapping.np[:req_indices.shape[0]] = np.where(
105
106
107
108
                mask, slot_mapping, -1)
        else:
            block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                   positions // self.block_size)
109
            block_numbers = self.block_table.np.ravel()[block_table_indices]
110
111
112
            block_offsets = positions % self.block_size
            np.add(block_numbers * self.block_size,
                   block_offsets,
113
                   out=self.slot_mapping.np[:req_indices.shape[0]])
114
115

    def commit_block_table(self, num_reqs: int) -> None:
116
        self.block_table.copy_to_gpu(num_reqs)
117

118
    def commit_slot_mapping(self, num_tokens: int) -> None:
119
        self.slot_mapping.copy_to_gpu(num_tokens)
120

121
    def clear(self) -> None:
122
123
        self.block_table.gpu.fill_(0)
        self.block_table.cpu.fill_(0)
124

125
    def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
126
        """Returns the device tensor of the block table."""
127
        return self.block_table.gpu[:num_reqs]
128
129
130

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
131
        return self.block_table.cpu
132
133
134

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
135
136
137
138
139
140
141
142
        return self.block_table.np

    def _make_buffer(self, *size: Union[int, torch.SymInt],
                     dtype: torch.dtype) -> CpuGpuBuffer:
        return CpuGpuBuffer(*size,
                            dtype=dtype,
                            device=self.device,
                            pin_memory=self.pin_memory)
143
144
145
146
147


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group.

    Holds one ``BlockTable`` per entry of ``block_sizes`` and forwards every
    row/commit operation to all of them.
    """

    def __init__(self,
                 max_num_reqs: int,
                 max_model_len: int,
                 max_num_batched_tokens: int,
                 pin_memory: bool,
                 device: torch.device,
                 block_sizes: list[int],
                 num_speculative_tokens: int = 0) -> None:
        # Note(hc): each DCP rank only stores about
        # (max_model_len // dcp_world_size) tokens in its KV cache, so the
        # block size used to compute max_num_blocks_per_req must be
        # multiplied by dcp_world_size.
        try:
            dcp_world_size = get_dcp_group().world_size
        except AssertionError:
            # DCP might not be initialized in testing
            dcp_world_size = 1

        self.block_tables = []
        for block_size in block_sizes:
            # Per-request block budget, never below 1 + num_speculative_tokens.
            max_num_blocks_per_req = max(
                cdiv(max_model_len, block_size * dcp_world_size),
                1 + num_speculative_tokens)
            self.block_tables.append(
                BlockTable(block_size, max_num_reqs, max_num_blocks_per_req,
                           max_num_batched_tokens, pin_memory, device))

    def _check_block_ids(self, block_ids: tuple[list[int], ...]) -> None:
        """Raise IndexError before any mutation if a group lacks block IDs."""
        if len(block_ids) < len(self.block_tables):
            raise IndexError(
                f"Expected block IDs for {len(self.block_tables)} KV cache "
                f"groups, got {len(block_ids)}")

    def append_row(self, block_ids: tuple[list[int], ...],
                   row_idx: int) -> None:
        """Append ``block_ids[i]`` to row ``row_idx`` of the i-th group.

        Raises:
            IndexError: if ``block_ids`` has fewer entries than there are
                groups (checked before any group is modified).
        """
        self._check_block_ids(block_ids)
        for block_table, group_block_ids in zip(self.block_tables, block_ids):
            block_table.append_row(group_block_ids, row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        """Replace row ``row_idx`` of the i-th group with ``block_ids[i]``.

        Raises:
            IndexError: if ``block_ids`` has fewer entries than there are
                groups (checked before any group is modified).
        """
        self._check_block_ids(block_ids)
        for block_table, group_block_ids in zip(self.block_tables, block_ids):
            block_table.add_row(group_block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy row ``src`` into row ``tgt`` in every group."""
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Swap rows ``src`` and ``tgt`` in every group."""
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        """Compute each group's slot mapping for the same batch of tokens."""
        for block_table in self.block_tables:
            block_table.compute_slot_mapping(req_indices, positions)

    def commit_block_table(self, num_reqs: int) -> None:
        """Copy the first ``num_reqs`` rows of every group to the GPU."""
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        """Copy the first ``num_tokens`` slots of every group to the GPU."""
        for block_table in self.block_tables:
            block_table.commit_slot_mapping(num_tokens)

    def clear(self) -> None:
        """Clear every group's block table."""
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]