"""A block manager that manages token blocks."""
import enum
from itertools import count
from os.path import commonprefix
from typing import Dict, List, Optional, Set, Tuple

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor


class BlockAllocator:
    """Manages physical token blocks for a single device.

    Blocks are created lazily: until ``num_blocks`` blocks exist, each
    allocation creates a new ``PhysicalTokenBlock`` with the next block
    number. Once that limit is reached, allocations recycle a block taken
    from the evictor.

    A block whose reference count drops to zero is handed to the evictor
    instead of being discarded. With caching enabled, blocks that are still
    referenced are indexed by hash in ``cached_blocks``, and a later request
    for the same hash can revive a block from the evictor. With caching
    disabled, the evictor is FIFO and every allocation receives a fresh,
    uniquely hashed block.
    """

    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
                 enable_caching: bool = False) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.enable_caching = enable_caching

        # Number of blocks created so far; allocate_block never lets this
        # exceed num_blocks.
        self.current_num_blocks = 0
        # block_hash -> block for blocks with a non-zero reference count.
        # Only populated when caching is enabled.
        self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

        # Switch over to FIFO eviction when caching is disabled
        if not self.enable_caching:
            eviction_policy = EvictionPolicy.FIFO
        self.evictor: Evictor = make_evictor(eviction_policy)

        # Source of unique placeholder hashes for blocks that have no
        # content hash.
        self.default_hash_ctr = count()

    def allocate_block(self, block_hash: int,
                       num_hashed_tokens: int) -> PhysicalTokenBlock:
        """Return a block tagged with ``block_hash`` without touching refs.

        Creates a new block while fewer than ``num_blocks`` exist; otherwise
        recycles one from the evictor (presumably this raises when the
        evictor is empty -- confirm against ``Evictor.evict``). The caller is
        responsible for incrementing ``ref_count``.
        """
        if self.current_num_blocks == self.num_blocks:
            block = self.evictor.evict()
            block.block_hash = block_hash
            block.num_hashed_tokens = num_hashed_tokens
            # The recycled block is about to hold different tokens, so a
            # "computed" mark left over from its previous contents is stale.
            # Clear it here instead of relying on the evictor to do so.
            block.computed = False
            return block
        block = PhysicalTokenBlock(device=self.device,
                                   block_number=self.current_num_blocks,
                                   block_size=self.block_size,
                                   block_hash=block_hash,
                                   num_hashed_tokens=num_hashed_tokens)
        self.current_num_blocks += 1
        return block

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        """Take one reference to a block for ``block_hash``.

        With caching disabled, ``block_hash`` is ignored and a fresh block
        is returned. With caching enabled, an existing block with the same
        hash is reused, whether it is still referenced or sitting in the
        evictor. Otherwise a new block is allocated and registered under
        that hash. ``None`` means "no content hash" and is replaced by a
        unique placeholder hash.

        Returns:
            The block, with its ``ref_count`` incremented by one.
        """
        # If caching is disabled, just allocate a new block and return it
        if not self.enable_caching:
            block = self.allocate_block(next(self.default_hash_ctr),
                                        num_hashed_tokens)
            block.ref_count += 1
            return block

        if block_hash is None:
            block_hash = next(self.default_hash_ctr)
        if block_hash in self.evictor:
            # Cache hit on an unreferenced block: revive it.
            assert block_hash not in self.cached_blocks
            block = self.evictor.remove(block_hash)
            assert block.ref_count == 0
            self.cached_blocks[block_hash] = block
            block.ref_count += 1
            assert block.block_hash == block_hash
            return block
        if block_hash not in self.cached_blocks:
            self.cached_blocks[block_hash] = self.allocate_block(
                block_hash, num_hashed_tokens)
        block = self.cached_blocks[block_hash]
        assert block.block_hash == block_hash
        block.ref_count += 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        """Drop one reference to ``block``.

        When the count reaches zero the block is handed to the evictor and,
        with caching enabled, removed from ``cached_blocks``.

        Raises:
            ValueError: if the block has no outstanding references.
        """
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            assert block.block_hash not in self.evictor
            self.evictor.add(block)

            # If caching is enabled, remove the block from the cached_blocks
            if self.enable_caching:
                del self.cached_blocks[block.block_hash]

    def get_num_free_blocks(self) -> int:
        """Blocks never created so far plus blocks waiting in the evictor."""
        never_created = self.num_blocks - self.current_num_blocks
        return never_created + self.evictor.num_blocks

    def contains_block(self, block_hash: int) -> bool:
        """Whether a referenced or evictable block has this hash."""
        return block_hash in self.cached_blocks or block_hash in self.evictor

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        """Re-key a referenced block under ``block_hash``.

        No-op unless caching is enabled. Requires that no other block
        already uses ``block_hash``.
        """
        # If caching is enabled, update the hash of block and the
        # cached_blocks dictionary.
        if self.enable_caching:
            assert not self.contains_block(block_hash)
            old_hash = block.block_hash
            block.block_hash = block_hash
            del self.cached_blocks[old_hash]
            self.cached_blocks[block_hash] = block


class AllocStatus(enum.Enum):
    """Outcome reported by ``BlockSpaceManager.can_allocate``.

    OK:    the sequence group can be given blocks right now.
    LATER: the group does not fit at the moment, but its requirement is
           within the allocator's capacity, so it may fit once blocks free up.
    NEVER: the group is too large to ever be allocated on the GPU.
    """
    OK = 1
    LATER = 2
    NEVER = 3


class BlockSpaceManager:
    """Manages the mapping between logical and physical token blocks.

    Every tracked sequence id maps to a ``BlockTable``: a list holding one
    ``PhysicalTokenBlock`` per logical block. GPU and CPU blocks come from
    two separate ``BlockAllocator`` instances. A block that appears in
    several sequences' tables carries one reference per table.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
        enable_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        # With a sliding window, logical blocks past the window reuse earlier
        # physical blocks (index modulo the window size in blocks).
        self.block_sliding_window = None
        if sliding_window is not None:
            assert sliding_window % block_size == 0, (sliding_window,
                                                      block_size)
            self.block_sliding_window = sliding_window // block_size

        self.watermark = watermark
        assert watermark >= 0.0

        self.enable_caching = enable_caching

        # GPU blocks held in reserve when admitting or swapping in work.
        self.watermark_blocks = int(watermark * num_gpu_blocks)
        self.gpu_allocator = BlockAllocator(Device.GPU,
                                            block_size,
                                            num_gpu_blocks,
                                            enable_caching=enable_caching)
        self.cpu_allocator = BlockAllocator(Device.CPU,
                                            block_size,
                                            num_cpu_blocks,
                                            enable_caching=enable_caching)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Check whether the group's prompt fits on the GPU.

        Returns NEVER when the prompt needs more blocks than the GPU has
        beyond the watermark, OK when enough blocks are free right now while
        keeping the watermark, and LATER otherwise.
        """
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
        num_required_blocks = len(seq.logical_token_blocks)

        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate GPU blocks for the prompt of a waiting sequence group.

        One block table is built from the first WAITING sequence's logical
        blocks. Every WAITING sequence in the group gets its own copy.
        """
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        seq = waiting_seqs[0]

        # Allocate new physical token blocks that will store the prompt tokens.
        num_prompt_blocks = len(seq.logical_token_blocks)

        block_table: BlockTable = []
        for logical_idx in range(num_prompt_blocks):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                block = block_table[logical_idx % self.block_sliding_window]
            else:
                # The allocator only looks at the hash when caching is
                # enabled, so don't compute it otherwise.
                block_hash = (seq.hash_of_block(logical_idx)
                              if self.enable_caching else None)
                block = self.gpu_allocator.allocate(
                    block_hash, seq.num_hashed_tokens_of_block(logical_idx))
            block_table.append(block)

        # Each waiting sequence receives a copy of the table, and
        # _free_block_table() releases one reference per table, so every
        # distinct block needs one reference per sequence. The allocator
        # already granted one; add the rest.
        extra_refs = len(waiting_seqs) - 1
        if extra_refs > 0:
            for block in set(block_table):
                block.ref_count += extra_refs

        # Assign the block table for each sequence.
        for waiting_seq in waiting_seqs:
            self.block_tables[waiting_seq.seq_id] = block_table.copy()

    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        """Whether each running sequence could get one more GPU block."""
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def _promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        """Key the full last block by its content hash.

        If a block with that hash already exists, ``last_block`` is released
        and the existing block is returned (with a new reference). Otherwise
        ``last_block`` is re-keyed in place and returned.
        """
        # Compute a new hash for the block so that it can be shared by other
        # Sequences
        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

        # if new_hash is already in the cached table, then free last_block
        # and return the cached version
        if self.gpu_allocator.contains_block(new_hash):
            self.gpu_allocator.free(last_block)
            return self.gpu_allocator.allocate(new_hash)
        else:
            self.gpu_allocator.update_hash(new_hash, last_block)
            return last_block

    def _is_last_block_full(
        self,
        seq: Sequence,
    ) -> bool:
        """Whether the sequence's token count is a non-zero block multiple."""
        token_ids_len = len(seq.data.get_token_ids())
        return token_ids_len > 0 and token_ids_len % seq.block_size == 0

    def _maybe_promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        """Promote ``last_block`` only if the sequence's last block is full."""
        if self._is_last_block_full(seq):
            return self._promote_last_block(seq, last_block)
        else:
            return last_block

    def _allocate_last_physical_block(
        self,
        seq: Sequence,
    ) -> PhysicalTokenBlock:
        """Allocate a GPU block for the sequence's last logical block.

        With caching enabled and a full last block, the block is requested by
        content hash, so an identical cached block may be shared. In every
        other case a fresh, unshared block is returned.
        """
        block_hash: Optional[int] = None
        # A content hash is only meaningful for a full block, and the
        # allocator ignores hashes entirely when caching is disabled.
        if self.enable_caching and self._is_last_block_full(seq):
            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
        num_hashed_tokens = seq.num_hashed_tokens_of_block(
            len(seq.logical_token_blocks) - 1)
        new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
        # A block requested without a content hash must be brand new.
        if block_hash is None:
            assert new_block.ref_count == 1
        return new_block

    def append_slot(
        self,
        seq: Sequence,
    ) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for a new token.

        Returns:
            ``None`` when the token can be written without copying, or a
            ``(src_block_number, dst_block_number)`` pair when the shared last
            block had to be replaced by a private copy (copy on write).
        """
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]
        # If we need to allocate a new physical block
        if len(block_table) < len(logical_blocks):
            # Currently this code only supports adding one physical block
            assert len(block_table) == len(logical_blocks) - 1

            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # reuse a block
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                new_block = self._allocate_last_physical_block(seq)
                block_table.append(new_block)
                return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            if self.enable_caching:
                # If the last block is now full, key it by content (or swap
                # it for an identical cached block) so it can be shared.
                # Without caching the allocator never shares by hash.
                block_table[-1] = self._maybe_promote_last_block(
                    seq, last_block)
            return None
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self._allocate_last_physical_block(seq)

            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return last_block.block_number, new_block.block_number

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Give ``child_seq`` a copy of the parent's table, sharing blocks."""
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        for block in src_block_table:
            block.ref_count += 1

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        """Distinct blocks used by the group's unfinished sequences."""
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        """Whether the group's blocks fit on the GPU above the watermark."""
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the blocks of every SWAPPED sequence from CPU to GPU.

        Returns:
            Mapping from each source CPU block number to its GPU block number.
        """
        # CPU block -> GPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for cpu_block in block_table:
                if cpu_block in mapping:
                    # Shared CPU block already moved: share the GPU copy too.
                    gpu_block = mapping[cpu_block]
                    gpu_block.ref_count += 1
                else:
                    gpu_block = self.gpu_allocator.allocate(
                        cpu_block.block_hash, cpu_block.num_hashed_tokens)
                    mapping[cpu_block] = gpu_block
                new_block_table.append(gpu_block)
                # Free the CPU block swapped in to GPU.
                self.cpu_allocator.free(cpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }
        return block_number_mapping

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Whether the CPU has a free block for each of the group's blocks."""
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the blocks of every RUNNING sequence from GPU to CPU.

        Returns:
            Mapping from each source GPU block number to its CPU block number.
        """
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for gpu_block in block_table:
                if gpu_block in mapping:
                    # Shared GPU block already moved: share the CPU copy too.
                    cpu_block = mapping[gpu_block]
                    cpu_block.ref_count += 1
                else:
                    cpu_block = self.cpu_allocator.allocate(
                        gpu_block.block_hash, gpu_block.num_hashed_tokens)
                    mapping[gpu_block] = cpu_block
                new_block_table.append(cpu_block)
                # Free the GPU block swapped out to CPU.
                self.gpu_allocator.free(gpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        """Release one reference to each distinct block in the table."""
        # A sliding-window table can list the same block several times, but
        # the table holds only one reference to it.
        for block in set(block_table):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        """Release the sequence's blocks and forget its block table."""
        if seq.seq_id not in self.block_tables:
            # Already freed or haven't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        """Release every tracked block table."""
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        """Physical block numbers of the sequence, in logical order."""
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()

    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        """Stamp ``access_time`` on every block of the sequence."""
        block_table = self.block_tables[seq.seq_id]
        for block in block_table:
            block.last_accessed = access_time

    def compute_last_full_block_in_seq(self, seq: Sequence):
        """Mark the sequence's last full block as computed.

        No-op when the sequence has no block table or no full block yet.
        """
        if seq.seq_id not in self.block_tables:
            return
        max_full_block = seq.get_len() // self.block_size - 1
        block_table = self.block_tables[seq.seq_id]
        if max_full_block == -1:
            return
        block_table[max_full_block].computed = True

    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
        """Block numbers up to and including the last block marked computed.

        Returns an empty list when the sequence has no block table or no
        computed block.
        """
        if seq.seq_id not in self.block_tables:
            return []
        block_table = self.block_tables[seq.seq_id]
        for block_idx in reversed(range(len(block_table))):
            if block_table[block_idx].computed:
                return [b.block_number for b in block_table[:block_idx + 1]]
        return []

    def get_common_computed_block_ids(self,
                                      seq_group: SequenceGroup) -> List[int]:
        """Longest common prefix of computed block numbers in the group.

        Sequences with no computed blocks are ignored. Always returns a
        list, which is empty when caching is disabled.
        """
        # Can return non-empty result only with prefix caching enabled.
        if not self.enable_caching:
            return []

        ids_list = [
            self.get_all_block_ids_till_computed(seq)
            for seq in seq_group.seqs_dict.values()
        ]
        non_empty = [ids for ids in ids_list if ids]
        if not non_empty:
            # os.path.commonprefix([]) returns '' rather than a list.
            return []
        return commonprefix(non_empty)

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        """Mark each sequence's last full block computed (caching only)."""
        # NOTE: We only mark the last full block because with prefix caching,
        # all blocks until the marked one are guaranteed to be computed.
        if self.enable_caching:
            for seq in seq_group.seqs_dict.values():
                self.compute_last_full_block_in_seq(seq)