block_manager.py 9.2 KB
Newer Older
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
1
from typing import Dict, List, Optional, Set, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3
4
5
6
7
8
9

from cacheflow.block import PhysicalTokenBlock
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceStatus
from cacheflow.utils import Device


Woosuk Kwon's avatar
Woosuk Kwon committed
10
class BlockManager:
    """Pool of physical token blocks for a single device.

    Blocks are reference counted: `allocate` hands out a block with
    `ref_count == 1`, and `free` decrements the count, returning the block to
    the free pool only once the count reaches 0.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        """Create `num_blocks` blocks of `block_size` tokens on `device`.

        Raises:
            ValueError: if `block_size` is not 8 or 16.
        """
        if block_size not in [8, 16]:
            # The two literals are concatenated, so the first one must end with
            # a separator or the sentence runs into the block size.
            raise ValueError(f'Unsupported block size: {block_size}. '
                             'The block size must be either 8 or 16.')

        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks.
        # TODO(woosuk): Make this a priority queue.
        self.free_blocks: List[PhysicalTokenBlock] = [
            PhysicalTokenBlock(device=device, block_number=i, block_size=block_size)
            for i in range(num_blocks)
        ]

    def allocate(self) -> PhysicalTokenBlock:
        """Take a block from the free pool and give it a reference count of 1.

        Raises:
            ValueError: if no free block is left.
        """
        if not self.free_blocks:
            raise ValueError('Out of memory! '
                             'No more free blocks are available.')
        # Pops from the end of the list, so blocks are handed out LIFO.
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        """Drop one reference to `block`; recycle it when no references remain.

        Raises:
            ValueError: if the block's reference count is already 0.
        """
        if block.ref_count == 0:
            raise ValueError('Double free! '
                             f'The block {block} is already freed.')
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        """Return how many blocks are currently in the free pool."""
        return len(self.free_blocks)


# Mapping: logical block number -> physical block.
# Entry i of a BlockTable is the physical block backing the sequence's i-th
# logical token block; the list grows as new logical blocks are appended.
BlockTable = List[PhysicalTokenBlock]


class BlockSpaceManager:
    """Maps every sequence's logical token blocks onto physical blocks.

    One `BlockManager` is kept for the GPU and one for the CPU. Physical
    blocks may be shared by several sequences of the same group; sharing is
    tracked through each block's `ref_count`.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        self.gpu_allocator = BlockManager(Device.GPU, block_size, num_gpu_blocks)
        self.cpu_allocator = BlockManager(Device.CPU, block_size, num_cpu_blocks)

        # seq_id -> BlockTable of that sequence.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> bool:
        """Return True if the GPU has enough free blocks for the group's prompt."""
        # NOTE: All sequences in the group are assumed to share one prompt, so
        # the first sequence stands in for the whole group.
        needed = len(seq_group.seqs[0].logical_token_blocks)
        available = self.gpu_allocator.get_num_free_blocks()
        return needed <= available

    def allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate GPU blocks for the prompt and share them across the group."""
        # NOTE: All sequences in the group are assumed to share one prompt.
        prompt_seq = seq_group.seqs[0]
        num_prompt_blocks = len(prompt_seq.logical_token_blocks)

        shared_table: BlockTable = []
        while len(shared_table) < num_prompt_blocks:
            physical = self.gpu_allocator.allocate()
            # Every sequence of the group holds one reference to this block.
            physical.ref_count = seq_group.num_seqs()
            shared_table.append(physical)

        # Each sequence gets its own list object pointing at the shared blocks.
        for member in seq_group.seqs:
            self.block_tables[member.seq_id] = list(shared_table)

    def can_append(self, seq_group: SequenceGroup) -> bool:
        """Return True if every running sequence could receive one more GPU block."""
        # Simple heuristic: require one free GPU block per running sequence.
        available = self.gpu_allocator.get_num_free_blocks()
        running = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return running <= available

    def append(self, seq: Sequence) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for the new token.

        Returns None when the token can be written without copying. Returns a
        (source, destination) pair of block numbers when the shared last block
        had to be replaced by a private one. Per the original design, the
        caller copies the source block's tokens into the destination.
        """
        logical_blocks = seq.logical_token_blocks
        table = self.block_tables[seq.seq_id]

        if len(table) < len(logical_blocks):
            # A new logical block was opened: back it with a fresh GPU block.
            table.append(self.gpu_allocator.allocate())
            return None

        tail = table[-1]
        assert tail.device == Device.GPU
        if tail.ref_count == 1:
            # Only this sequence references the block, so append in place.
            return None

        # The last block is shared with other sequences: give this sequence a
        # private block and release its reference to the shared one.
        replacement = self.gpu_allocator.allocate()
        table[-1] = replacement
        self.gpu_allocator.free(tail)
        return tail.block_number, replacement.block_number

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Make child_seq share every physical block of parent_seq."""
        # NOTE: Nothing is allocated here, so forking can never run out of memory.
        parent_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = list(parent_table)
        for shared in parent_table:
            shared.ref_count += 1

    def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        """Return the distinct physical blocks used by the group's unfinished sequences."""
        # NOTE: Physical blocks are assumed to be shared only among sequences
        # of the same group.
        distinct: Set[PhysicalTokenBlock] = {
            block
            for seq in seq_group.seqs
            if seq.status != SequenceStatus.FINISHED
            for block in self.block_tables[seq.seq_id]
        }
        return list(distinct)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        """Return True if the group's blocks (plus headroom) fit in free GPU blocks."""
        in_use = len(self._get_physical_blocks(seq_group))
        # NOTE: Conservatively reserve one extra block per swapped sequence,
        # since each is assumed to allocate right after swap-in. This mirrors
        # the headroom rule in can_append().
        headroom = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        available = self.gpu_allocator.get_num_free_blocks()
        return in_use + headroom <= available

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the group's blocks from CPU to GPU; return CPU->GPU block numbers."""
        return self._move_blocks(seq_group, self.cpu_allocator, self.gpu_allocator)

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Return True if the group's blocks fit in the free CPU blocks."""
        in_use = len(self._get_physical_blocks(seq_group))
        available = self.cpu_allocator.get_num_free_blocks()
        return in_use <= available

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the group's blocks from GPU to CPU; return GPU->CPU block numbers."""
        return self._move_blocks(seq_group, self.gpu_allocator, self.cpu_allocator)

    def _move_blocks(
        self,
        seq_group: SequenceGroup,
        src_allocator: BlockManager,
        dst_allocator: BlockManager,
    ) -> Dict[int, int]:
        """Re-home the block tables of the group's unfinished sequences.

        A source block referenced by several sequences is mapped to a single
        destination block, so sharing survives the move. Each reference to a
        source block is released once it has been replaced.

        Returns:
            {source block number: destination block number}.
        """
        moved: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
                continue
            relocated: BlockTable = []
            for src_block in self.block_tables[seq.seq_id]:
                if src_block in moved:
                    # Already moved for an earlier sequence: share that copy.
                    dst_block = moved[src_block]
                    dst_block.ref_count += 1
                else:
                    dst_block = dst_allocator.allocate()
                    moved[src_block] = dst_block
                relocated.append(dst_block)
                # Release this sequence's reference to the source block.
                src_allocator.free(src_block)
            self.block_tables[seq.seq_id] = relocated

        return {src.block_number: dst.block_number for src, dst in moved.items()}

    def _free_block_table(self, block_table: BlockTable) -> None:
        """Release one reference to each block, using the block's own device pool."""
        for block in block_table:
            owner = self.gpu_allocator if block.device == Device.GPU else self.cpu_allocator
            owner.free(block)

    def free(self, seq: Sequence) -> None:
        """Release all blocks of `seq` and forget its block table."""
        self._free_block_table(self.block_tables[seq.seq_id])
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        """Release every tracked block table and clear the mapping."""
        for table in self.block_tables.values():
            self._free_block_table(table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return the physical block numbers of `seq`, in logical-block order."""
        return [physical.block_number for physical in self.block_tables[seq.seq_id]]