"""A block manager that manages token blocks."""
from typing import Dict, List, Optional, Set, Tuple

from vllm.block import PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device


class BlockAllocator:
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks.
        self.free_blocks: List[PhysicalTokenBlock] = []
        for i in range(num_blocks):
            block = PhysicalTokenBlock(device=device,
                                       block_number=i,
                                       block_size=block_size)
            self.free_blocks.append(block)

    def allocate(self) -> PhysicalTokenBlock:
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)
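

# A minimal usage sketch of BlockAllocator's reference counting (illustrative
# only; this helper is not part of the original module). It assumes the
# BlockAllocator defined above and the Device enum imported from vllm.utils.
def _block_allocator_refcount_example() -> None:
    allocator = BlockAllocator(Device.GPU, block_size=16, num_blocks=4)
    block = allocator.allocate()  # ref_count is set to 1
    assert allocator.get_num_free_blocks() == 3
    block.ref_count += 1  # simulate a second sequence sharing the block
    allocator.free(block)  # ref_count 2 -> 1; block stays in use
    assert allocator.get_num_free_blocks() == 3
    allocator.free(block)  # ref_count 1 -> 0; block returns to the free list
    assert allocator.get_num_free_blocks() == 4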


# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]


class BlockSpaceManager:
    """Manages the mapping between logical and physical token blocks."""

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks
        self.watermark = watermark
        assert watermark >= 0.0

        self.watermark_blocks = int(watermark * num_gpu_blocks)
        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
                                            num_gpu_blocks)
        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
                                            num_cpu_blocks)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> bool:
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs()[0]
        num_required_blocks = len(seq.logical_token_blocks)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        # Use watermark to avoid frequent cache eviction.
        return (num_free_gpu_blocks - num_required_blocks >=
                self.watermark_blocks)
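
    # Worked example of the watermark check (illustrative numbers): with
    # num_gpu_blocks=1000 and the default watermark=0.01, watermark_blocks is
    # int(0.01 * 1000) == 10. A prompt that needs 32 blocks is then admitted
    # only while at least 32 + 10 = 42 GPU blocks are free, keeping a small
    # reserve so the cache is not evicted the moment memory fills up.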

    def allocate(self, seq_group: SequenceGroup) -> None:
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = seq_group.get_seqs()[0]

        # Allocate new physical token blocks that will store the prompt tokens.
        block_table: BlockTable = []
        for _ in range(len(seq.logical_token_blocks)):
            block = self.gpu_allocator.allocate()
            # Set the reference counts of the token blocks.
            block.ref_count = seq_group.num_seqs()
            block_table.append(block)

        # Assign the block table for each sequence.
        for seq in seq_group.get_seqs():
            self.block_tables[seq.seq_id] = block_table.copy()

    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for a new token."""
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]

        if len(block_table) < len(logical_blocks):
            # The sequence has a new logical block.
            # Allocate a new physical block.
            block = self.gpu_allocator.allocate()
            block_table.append(block)
            return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            return None
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self.gpu_allocator.allocate()
            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return last_block.block_number, new_block.block_number
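
    # Sketch of how a caller might consume append_slot's return value (the
    # variable names below are hypothetical): None means the token can be
    # written in place, while a (src, dst) pair asks the caller to copy the
    # contents of physical block src into block dst before writing, which
    # completes the copy-on-write.
    #
    #     ret = block_manager.append_slot(seq)
    #     if ret is not None:
    #         src_block, dst_block = ret
    #         blocks_to_copy.setdefault(src_block, []).append(dst_block)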

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        for block in src_block_table:
            block.ref_count += 1
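
    # Net effect of fork, as a sketch: parent and child now map their logical
    # blocks to the same physical blocks, and every shared block's ref_count
    # has grown by one, so a later append_slot on either sequence takes the
    # copy-on-write path instead of overwriting the other's tokens.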

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            block_table = self.block_tables[seq.seq_id]
            for block in block_table:
                blocks.add(block)
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks
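
    # Worked example for can_swap_in (illustrative numbers): a swapped group
    # holding 20 CPU blocks with 2 swapped sequences needs 20 + 2 = 22 GPU
    # blocks, so with watermark_blocks == 10 the swap-in is allowed only while
    # at least 32 GPU blocks are free.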

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # CPU block -> GPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for cpu_block in block_table:
                if cpu_block in mapping:
                    gpu_block = mapping[cpu_block]
                    gpu_block.ref_count += 1
                else:
                    gpu_block = self.gpu_allocator.allocate()
                    mapping[cpu_block] = gpu_block
                new_block_table.append(gpu_block)
                # Free the CPU block swapped in to GPU.
                self.cpu_allocator.free(cpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }
        return block_number_mapping
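
    # Sketch of how the returned mapping might be consumed (hypothetical
    # caller): each entry reads "copy CPU block i into GPU block j", one
    # block of KV cache per pair.
    #
    #     for cpu_bn, gpu_bn in block_manager.swap_in(seq_group).items():
    #         copy_cpu_block_to_gpu(cpu_bn, gpu_bn)  # hypothetical helper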

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for gpu_block in block_table:
                if gpu_block in mapping:
                    cpu_block = mapping[gpu_block]
                    cpu_block.ref_count += 1
                else:
                    cpu_block = self.cpu_allocator.allocate()
                    mapping[gpu_block] = cpu_block
                new_block_table.append(cpu_block)
                # Free the GPU block swapped out to CPU.
                self.gpu_allocator.free(gpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        for block in block_table:
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:
            # Already freed or haven't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()
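

# Minimal smoke-test sketch (illustrative; not part of the original module).
# It exercises only what this file defines, so it avoids constructing
# Sequence or SequenceGroup objects.
if __name__ == "__main__":
    _block_allocator_refcount_example()

    manager = BlockSpaceManager(block_size=16,
                                num_gpu_blocks=1000,
                                num_cpu_blocks=512)
    print(f"watermark blocks: {manager.watermark_blocks}")  # 10
    print(f"free GPU blocks: {manager.get_num_free_gpu_blocks()}")  # 1000
    print(f"free CPU blocks: {manager.get_num_free_cpu_blocks()}")  # 512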