block_manager.py 11.7 KB
Newer Older
1
"""A block manager that manages token blocks."""
2
import enum
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
3
from typing import Dict, List, Optional, Set, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
4

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
from vllm.block import PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
Woosuk Kwon's avatar
Woosuk Kwon committed
8
9


Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
10
class BlockAllocator:
    """Hands out and reclaims the physical token blocks of one device.

    Every block is created up front and parked on a free list. ``allocate``
    takes a block off that list and gives it a reference count of one;
    ``free`` drops one reference and puts the block back on the free list
    once no references remain.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Pre-create every physical block (numbered 0..num_blocks-1); all of
        # them start out free.
        self.free_blocks: List[PhysicalTokenBlock] = [
            PhysicalTokenBlock(device=device,
                               block_number=block_number,
                               block_size=block_size)
            for block_number in range(num_blocks)
        ]

    def allocate(self) -> PhysicalTokenBlock:
        """Take a free block and give it a single reference.

        Raises:
            ValueError: if no free block is left.
        """
        if len(self.free_blocks) == 0:
            raise ValueError("Out of memory! No free blocks are available.")
        # LIFO: the most recently freed block is handed out first.
        taken = self.free_blocks.pop()
        taken.ref_count = 1
        return taken

    def free(self, block: PhysicalTokenBlock) -> None:
        """Drop one reference to ``block``; recycle it when none remain.

        Raises:
            ValueError: if ``block`` already has a reference count of zero.
        """
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        """Return how many blocks are currently on the free list."""
        return len(self.free_blocks)


# A sequence's block table. Mapping: logical block number -> physical block,
# i.e. entry i is the physical block backing logical block i. With a sliding
# window, the same physical block may appear at several indices (see
# BlockSpaceManager.allocate / append_slot).
BlockTable = List[PhysicalTokenBlock]


58
59
60
61
62
63
64
65
66
67
68
69
70
71
class AllocStatus(enum.Enum):
    """Result of ``BlockSpaceManager.can_allocate``.

    1. OK: the sequence group can be allocated now.
    2. LATER: the sequence group cannot be allocated right now, but the
       GPU's total capacity is large enough for it, so it may fit once
       other blocks are freed.
    3. NEVER: the sequence group can never be allocated, because it needs
       more GPU blocks than the GPU can provide while keeping the
       watermark reserve free.
    """
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()


Woosuk Kwon's avatar
Woosuk Kwon committed
72
class BlockSpaceManager:
    """Manages the mapping between logical and physical token blocks.

    Every sequence owns a block table: a list whose i-th entry is the
    physical block backing the sequence's i-th logical block. Physical
    blocks come from a GPU allocator or a CPU allocator (used for swapping)
    and may be shared between the sequences of a group; a block's
    ``ref_count`` is the number of sequences that hold it.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
    ) -> None:
        """
        Args:
            block_size: size of each physical block, forwarded to both
                allocators.
            num_gpu_blocks: total number of physical blocks on the GPU.
            num_cpu_blocks: total number of physical blocks on the CPU.
            watermark: fraction of GPU blocks kept in reserve when deciding
                whether new work can be admitted.
            sliding_window: optional window size in tokens; must be a
                positive multiple of ``block_size``. When set, block tables
                reuse physical blocks in a ring of
                ``sliding_window // block_size`` entries.
        """
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        # Sliding window expressed in blocks rather than tokens.
        self.block_sliding_window: Optional[int] = None
        if sliding_window is not None:
            # A non-positive window would later crash in allocate() with a
            # modulo by zero / index into an empty table.
            assert sliding_window > 0, sliding_window
            assert sliding_window % block_size == 0, (sliding_window,
                                                      block_size)
            self.block_sliding_window = sliding_window // block_size

        self.watermark = watermark
        assert watermark >= 0.0

        # Number of GPU blocks kept free when admitting new work.
        self.watermark_blocks = int(watermark * num_gpu_blocks)

        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
                                            num_gpu_blocks)
        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
                                            num_cpu_blocks)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Check whether the prompt of ``seq_group`` fits on the GPU.

        Returns:
            ``AllocStatus.NEVER`` if the prompt needs more blocks than the
            GPU has after reserving the watermark, ``AllocStatus.OK`` if it
            fits into the currently free blocks while keeping the watermark,
            and ``AllocStatus.LATER`` otherwise.
        """
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs()[0]
        num_required_blocks = len(seq.logical_token_blocks)
        if self.block_sliding_window is not None:
            # Blocks past the window are reused, so never need more than it.
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        return AllocStatus.LATER

    def allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate GPU blocks for the prompt of ``seq_group``.

        One block table is built from the first sequence and every sequence
        in the group receives its own copy, so the physical blocks are
        shared and their reference counts equal the group's sequence count.

        Raises:
            ValueError: if the GPU allocator runs out of free blocks.
        """
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = seq_group.get_seqs()[0]

        # Allocate new physical token blocks that will store the prompt tokens.
        block_table: BlockTable = []
        for logical_idx in range(len(seq.logical_token_blocks)):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                # Past the window: reuse the block at the same ring position.
                block = block_table[logical_idx % self.block_sliding_window]
            else:
                block = self.gpu_allocator.allocate()
            # Set the reference counts of the token blocks.
            block.ref_count = seq_group.num_seqs()
            block_table.append(block)

        # Assign the block table for each sequence.
        for seq in seq_group.get_seqs():
            self.block_tables[seq.seq_id] = block_table.copy()

    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        """Return True if every running sequence could get one new block."""
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for a new token.

        Returns:
            ``None`` if the token can be written without copying, or a
            ``(src_block_number, dst_block_number)`` pair when the last
            block was shared and has been replaced by a fresh block
            (copy-on-write); the pair identifies the copy to perform.

        Raises:
            ValueError: if the GPU allocator has no free block.
        """
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]

        if len(block_table) < len(logical_blocks):
            if (self.block_sliding_window is not None
                    and len(block_table) >= self.block_sliding_window):
                # re-use a block
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                block = self.gpu_allocator.allocate()
                block_table.append(block)
                return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            return None
        # The last block is shared with other sequences.
        # Copy on Write: Allocate a new block and copy the tokens.
        new_block = self.gpu_allocator.allocate()
        block_table[-1] = new_block
        self.gpu_allocator.free(last_block)
        return last_block.block_number, new_block.block_number

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Give ``child_seq`` a copy of ``parent_seq``'s block table.

        The physical blocks are shared, so each distinct block gains one
        reference.
        """
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        # With a sliding window the table can list a physical block more
        # than once. The child holds one reference per distinct block
        # (_free_block_table also frees each distinct block once), so
        # deduplicate before incrementing.
        for block in set(src_block_table):
            block.ref_count += 1

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        """Return the distinct blocks used by unfinished seqs of the group."""
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        """Return True if the group's blocks fit on the GPU above watermark."""
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks

    def _swap_block_table(
        self,
        block_table: BlockTable,
        src_allocator: BlockAllocator,
        dst_allocator: BlockAllocator,
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock],
    ) -> BlockTable:
        """Move one sequence's block table to another allocator's blocks.

        ``mapping`` (source block -> destination block) is shared by all
        sequences of a group, so a block shared by several sequences gets a
        single destination block; each further sequence adds one reference
        to it. Returns the new table, with the same order and length.
        """
        new_block_table: BlockTable = []
        # With a sliding window a table can repeat a physical block. Ref
        # counts track sequences, not table entries, so each distinct block
        # is counted and freed only once per sequence.
        seen: Set[PhysicalTokenBlock] = set()
        for src_block in block_table:
            if src_block in seen:
                new_block_table.append(mapping[src_block])
                continue
            seen.add(src_block)
            if src_block in mapping:
                dst_block = mapping[src_block]
                dst_block.ref_count += 1
            else:
                dst_block = dst_allocator.allocate()
                mapping[src_block] = dst_block
            new_block_table.append(dst_block)
            # Release this sequence's reference on the source device.
            src_allocator.free(src_block)
        return new_block_table

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the swapped-out sequences of ``seq_group`` back to the GPU.

        Returns:
            Mapping from CPU block number to the GPU block number that
            replaces it.
        """
        # CPU block -> GPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            self.block_tables[seq.seq_id] = self._swap_block_table(
                self.block_tables[seq.seq_id], self.cpu_allocator,
                self.gpu_allocator, mapping)

        return {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Return True if the CPU has room for all of the group's blocks."""
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the running sequences of ``seq_group`` to the CPU.

        Returns:
            Mapping from GPU block number to the CPU block number that
            replaces it.
        """
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            self.block_tables[seq.seq_id] = self._swap_block_table(
                self.block_tables[seq.seq_id], self.gpu_allocator,
                self.cpu_allocator, mapping)

        return {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }

    def _free_block_table(self, block_table: BlockTable) -> None:
        """Drop one reference to each distinct block in ``block_table``."""
        for block in set(block_table):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        """Release ``seq``'s blocks and forget its block table."""
        if seq.seq_id not in self.block_tables:
            # Already freed or haven't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        """Release every tracked block table."""
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return the physical block numbers of ``seq``'s block table."""
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        """Return the number of free GPU blocks."""
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        """Return the number of free CPU blocks."""
        return self.cpu_allocator.get_num_free_blocks()