"""A block manager that manages token blocks."""
import enum
from typing import Dict, List, Optional, Set, Tuple

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device


class BlockAllocator:
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks.
        self.free_blocks: BlockTable = []
        for i in range(num_blocks):
            block = PhysicalTokenBlock(device=device,
                                       block_number=i,
                                       block_size=block_size)
            self.free_blocks.append(block)

    def allocate(self) -> PhysicalTokenBlock:
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)
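
# Illustrative sketch: how the free list and reference counts interact,
# using a hypothetical two-block pool.
#
#   allocator = BlockAllocator(Device.GPU, block_size=16, num_blocks=2)
#   block = allocator.allocate()   # ref_count == 1; one block left free
#   block.ref_count += 1           # e.g. a second sequence shares the block
#   allocator.free(block)          # ref_count -> 1; block stays allocated
#   allocator.free(block)          # ref_count -> 0; block returns to the pool
#   assert allocator.get_num_free_blocks() == 2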


class AllocStatus(enum.Enum):
    """Result for BlockSpaceManager.can_allocate.

    1. OK: the seq_group can be allocated now.
    2. LATER: the seq_group cannot be allocated now, but the allocator's
       total capacity is large enough, so retry after blocks are freed.
    3. NEVER: the seq_group can never be allocated because it is too large
       to fit in the GPU blocks.
    """
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()


class BlockSpaceManager:
    """Manages the mapping between logical and physical token blocks."""

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        self.block_sliding_window = None
        if sliding_window is not None:
            assert sliding_window % block_size == 0, (sliding_window,
                                                      block_size)
            self.block_sliding_window = sliding_window // block_size

        self.watermark = watermark
        assert watermark >= 0.0

        self.watermark_blocks = int(watermark * num_gpu_blocks)
        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
                                            num_gpu_blocks)
        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
                                            num_cpu_blocks)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
        num_required_blocks = len(seq.logical_token_blocks)

        if seq_group.prefix is not None and seq_group.prefix.allocated:
            num_required_blocks -= seq_group.prefix.get_num_blocks()

        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER
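
    # Worked example with hypothetical numbers: for num_gpu_blocks=1000 and
    # watermark=0.01, watermark_blocks == 10. A prompt needing 995 blocks is
    # NEVER allocatable (1000 - 995 < 10); one needing 200 blocks is OK while
    # at least 210 GPU blocks are free, and LATER otherwise.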

    def allocate(self, seq_group: SequenceGroup) -> None:
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

        # Allocate new physical token blocks that will store the prompt tokens.
        num_prompt_blocks = len(seq.logical_token_blocks)

        block_table: BlockTable = []
        prefix_block_table: BlockTable = []
        num_prefix_blocks = 0

        prefix = seq_group.prefix
        if prefix is not None and prefix.allocated:
            # Prefix has already been allocated. Use the existing block table.
            num_prompt_blocks -= prefix.get_num_blocks()
            for block in prefix.block_table:
                block.ref_count += seq_group.num_seqs()
                block_table.append(block)

        for logical_idx in range(num_prompt_blocks):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                block = block_table[logical_idx % self.block_sliding_window]
            else:
                block = self.gpu_allocator.allocate()
            # Set the reference counts of the token blocks.
            block.ref_count = seq_group.num_seqs()
            block_table.append(block)

        if prefix is not None and not prefix.allocated:
            # Allocate blocks for the prefix, we will compute the prefix's
            # KV cache in this run.
            num_prefix_blocks = prefix.get_num_blocks()
            prefix_block_table = block_table[:num_prefix_blocks]
            for block in prefix_block_table:
                block.ref_count += 1
            prefix.set_block_table(prefix_block_table)

        # Assign the block table for each sequence.
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            self.block_tables[seq.seq_id] = block_table.copy()
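
    # Each prompt block above starts with ref_count == num_seqs, so a block
    # is returned to the allocator only after every sequence in the group
    # has freed it.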

    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for a new token."""
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]

        if len(block_table) < len(logical_blocks):
            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # re-use a block
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                block = self.gpu_allocator.allocate()
                block_table.append(block)
                return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            return None
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self.gpu_allocator.allocate()
            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return last_block.block_number, new_block.block_number
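
    # Illustrative sketch of a hypothetical caller: the returned (src, dst)
    # pair is meant to become an actual cache copy, e.g.:
    #
    #   ret = block_manager.append_slot(seq)
    #   if ret is not None:
    #       src_block, dst_block = ret
    #       blocks_to_copy.setdefault(src_block, []).append(dst_block)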

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        for block in src_block_table:
            block.ref_count += 1
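
    # After fork(), parent and child share every physical block, so the first
    # append_slot() that writes to the shared last block takes the
    # copy-on-write path above.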

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks
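
    # Worked example with hypothetical numbers: a group holding 20 physical
    # blocks with 2 swapped sequences needs 22 free GPU blocks on top of the
    # watermark to swap in.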

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # CPU block -> GPU block.
        if seq_group.prefix is not None:
            # make sure to swap in the prefix first
            assert seq_group.prefix.allocated and seq_group.prefix.computed

        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]
            if seq_group.prefix is not None:
                for block in seq_group.prefix.block_table:
                    new_block_table.append(block)
                    block.ref_count += 1

            for cpu_block in block_table:
                if cpu_block in mapping:
                    gpu_block = mapping[cpu_block]
                    gpu_block.ref_count += 1
                else:
                    gpu_block = self.gpu_allocator.allocate()
                    mapping[cpu_block] = gpu_block
                new_block_table.append(gpu_block)
                # Free the CPU block swapped in to GPU.
                self.cpu_allocator.free(cpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }
        return block_number_mapping
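
    # The mapping is keyed by block numbers rather than block objects so a
    # caller can, for instance, hand it to the cache engine to copy each CPU
    # block's contents to its new GPU block.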

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for gpu_block in block_table:
                if (seq_group.prefix is not None
                        and gpu_block in seq_group.prefix.block_table):
                    # NOTE: We do not swap out the prefix blocks for now.
                    self.gpu_allocator.free(gpu_block)
                    continue

                if gpu_block in mapping:
                    cpu_block = mapping[gpu_block]
                    cpu_block.ref_count += 1
                else:
                    cpu_block = self.cpu_allocator.allocate()
                    mapping[gpu_block] = cpu_block
                new_block_table.append(cpu_block)
                # Free the GPU block swapped out to CPU.
                self.gpu_allocator.free(gpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        for block in set(block_table):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:
            # Already freed or haven't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()