# block_manager.py
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional, Set, Tuple

from cacheflow.block import PhysicalTokenBlock
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
from cacheflow.utils import Device


class BlockAllocator:
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        """Create ``num_blocks`` blocks on ``device``, all of them free.

        Args:
            device: Device recorded on every block this allocator owns.
            block_size: Size passed to each ``PhysicalTokenBlock``.
            num_blocks: Number of blocks to create; they are numbered
                ``0 .. num_blocks - 1``.
        """
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks. The list is used as a stack: allocate()
        # pops from the end and free() appends to the end.
        self.free_blocks: List[PhysicalTokenBlock] = [
            PhysicalTokenBlock(
                device=device, block_number=i, block_size=block_size)
            for i in range(num_blocks)
        ]

    def allocate(self) -> PhysicalTokenBlock:
        """Take a block off the free list and give it one reference.

        Returns:
            The allocated block, with ``ref_count`` set to 1.

        Raises:
            ValueError: If no free blocks are available.
        """
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        """Drop one reference to ``block``.

        The block goes back on the free list only when its last reference is
        dropped, so a block shared by several owners stays allocated until
        every owner has freed it.

        Raises:
            ValueError: If ``block`` has no outstanding references.
        """
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        """Return how many blocks are currently on the free list."""
        return len(self.free_blocks)


# Mapping: logical block number -> physical block. Entry i of a sequence's
# block table is the physical block backing its i-th logical token block;
# entries are GPU blocks normally and CPU blocks while the sequence is
# swapped out.
BlockTable = List[PhysicalTokenBlock]


class BlockSpaceManager:
    """Manages the mapping between logical and physical token blocks.

    Every tracked sequence has a block table in ``self.block_tables`` whose
    i-th entry is the physical block holding its i-th logical token block.
    Physical blocks can be shared by several sequences: all sequences of a
    group share the prompt blocks after ``allocate``, and a child shares its
    parent's blocks after ``fork``. Sharing is tracked with each block's
    ``ref_count`` and is broken by copy-on-write in ``append_slot``. Blocks
    move between the GPU and CPU pools via ``swap_in`` and ``swap_out``.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
    ) -> None:
        """Create the GPU and CPU block pools.

        Args:
            block_size: Size of every physical block, forwarded to both
                allocators.
            num_gpu_blocks: Number of blocks in the GPU pool.
            num_cpu_blocks: Number of blocks in the CPU pool.
            watermark: Fraction of the GPU pool that ``can_allocate`` and
                ``can_swap_in`` keep in reserve. Must be non-negative.

        Raises:
            ValueError: If ``watermark`` is negative or NaN.
        """
        # Validate with an explicit raise rather than ``assert`` so the check
        # survives ``python -O``. ``not >=`` also rejects NaN, as the old
        # assertion did.
        if not watermark >= 0.0:
            raise ValueError(
                f"watermark must be non-negative, got {watermark}.")

        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        self.watermark = watermark
        # GPU blocks held in reserve when admitting new work, to avoid
        # frequent cache eviction.
        self.watermark_blocks = int(watermark * num_gpu_blocks)

        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
                                            num_gpu_blocks)
        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
                                            num_cpu_blocks)

        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> bool:
        """Check whether the prompt fits on the GPU above the watermark."""
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs()[0]
        num_required_blocks = len(seq.logical_token_blocks)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        # Use watermark to avoid frequent cache eviction.
        return (num_free_gpu_blocks - num_required_blocks >=
                self.watermark_blocks)

    def allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate GPU blocks for the group's prompt, shared by all its seqs.

        One physical block is allocated per logical block of the first
        sequence. Every sequence in the group gets its own copy of the
        resulting block table, and each block's reference count is set to
        the number of sequences in the group.

        Raises:
            ValueError: If the GPU runs out of free blocks. Blocks obtained
                before the failure are returned to the allocator, so a failed
                call leaves the free list exactly as it was.
        """
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seqs = seq_group.get_seqs()
        num_blocks = len(seqs[0].logical_token_blocks)

        # Allocate new physical token blocks that will store the prompt tokens.
        block_table: BlockTable = []
        try:
            for _ in range(num_blocks):
                block_table.append(self.gpu_allocator.allocate())
        except ValueError:
            # Roll back the partial allocation so those blocks are not leaked.
            # Each block still has ref_count == 1, so one free() returns it;
            # freeing in reverse restores the original free-list order.
            for block in reversed(block_table):
                self.gpu_allocator.free(block)
            raise

        # Set the reference counts of the token blocks: every sequence in the
        # group references every prompt block.
        num_seqs = seq_group.num_seqs()
        for block in block_table:
            block.ref_count = num_seqs

        # Assign the block table for each sequence.
        for seq in seqs:
            self.block_tables[seq.seq_id] = block_table.copy()

    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        """Check whether each running sequence could get one new GPU block."""
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for a new token.

        Returns:
            ``None`` when the token can be written without copying. This
            happens when the sequence has started a new logical block, in
            which case a fresh GPU block is appended to its table. It also
            happens when the last physical block is referenced by this
            sequence alone. Otherwise the last block is shared: it is
            replaced in this sequence's table by a freshly allocated block,
            and ``(old_block_number, new_block_number)`` is returned. The
            token copy itself is not performed here.

        Raises:
            ValueError: If a new GPU block is needed but none is free.
        """
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]

        if len(block_table) < len(logical_blocks):
            # The sequence has a new logical block.
            # Allocate a new physical block.
            block = self.gpu_allocator.allocate()
            block_table.append(block)
            return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            return None

        # The last block is shared with other sequences.
        # Copy on Write: Allocate a new block and copy the tokens.
        new_block = self.gpu_allocator.allocate()
        block_table[-1] = new_block
        # Drop only this sequence's reference; the other sharers keep theirs.
        self.gpu_allocator.free(last_block)
        return last_block.block_number, new_block.block_number

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Make ``child_seq`` share every physical block of ``parent_seq``."""
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        for block in src_block_table:
            block.ref_count += 1

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        """Return the distinct physical blocks of the unfinished sequences."""
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if SequenceStatus.is_finished(seq.status):
                continue
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        """Check whether the group can be swapped into GPU above watermark."""
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks

    def _move_blocks(
        self,
        seq_group: SequenceGroup,
        src_allocator: BlockAllocator,
        dst_allocator: BlockAllocator,
    ) -> Dict[int, int]:
        """Move the unfinished sequences' blocks from one pool to another.

        A source block shared by several sequences maps to a single
        destination block. That block is allocated on first encounter and
        gains one reference for every further sequence that uses it, so
        sharing survives the move. Each sequence's reference to a source
        block is released as that block is moved.

        Returns:
            Mapping from source block number to destination block number.

        Raises:
            ValueError: If the destination pool runs out of free blocks.
        """
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs():
            if SequenceStatus.is_finished(seq.status):
                continue
            new_block_table: BlockTable = []
            for src_block in self.block_tables[seq.seq_id]:
                if src_block in mapping:
                    dst_block = mapping[src_block]
                    dst_block.ref_count += 1
                else:
                    dst_block = dst_allocator.allocate()
                    mapping[src_block] = dst_block
                new_block_table.append(dst_block)
                # Free this sequence's reference to the moved source block.
                src_allocator.free(src_block)
            self.block_tables[seq.seq_id] = new_block_table

        return {
            src_block.block_number: dst_block.block_number
            for src_block, dst_block in mapping.items()
        }

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the group's blocks from CPU to GPU.

        Returns:
            Mapping from CPU block number to GPU block number.
        """
        # CPU block -> GPU block.
        return self._move_blocks(seq_group, self.cpu_allocator,
                                 self.gpu_allocator)

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Check whether the CPU pool has room for the group's blocks."""
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the group's blocks from GPU to CPU.

        Returns:
            Mapping from GPU block number to CPU block number.
        """
        # GPU block -> CPU block.
        return self._move_blocks(seq_group, self.gpu_allocator,
                                 self.cpu_allocator)

    def _free_block_table(self, block_table: BlockTable) -> None:
        """Drop one reference to each block, routed to its device's pool."""
        for block in block_table:
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        """Release ``seq``'s blocks and stop tracking it.

        Raises:
            KeyError: If ``seq`` has no block table.
        """
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        """Release every tracked block table and forget all sequences."""
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return the physical block numbers of ``seq``, in logical order."""
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        """Return the number of free blocks in the GPU pool."""
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        """Return the number of free blocks in the CPU pool."""
        return self.cpu_allocator.get_num_free_blocks()