scheduler.py 9.63 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
7
import copy
from typing import Dict, List, Tuple

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceStatus

8
9
_MAX_NUM_BATCHED_TOKENS = 2048

Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
12

class Scheduler:
    """Decide which sequence groups take part in each model step.

    Sequence groups live in one of three queues:

    * ``pending`` -- not yet allocated; admitted in FIFO order.
    * ``running`` -- hold GPU blocks and take part in the next step (FIFO).
    * ``swapped`` -- preempted, with their blocks swapped out (LIFO).

    Per-group bookkeeping is keyed by ``group_id``. ``num_steps`` is
    initialized by ``_allocate``. ``max_num_steps`` and ``stop_token_ids``
    are read by ``post_step`` but never written by this class.
    NOTE(review): presumably whoever appends to ``pending`` also fills those
    two maps -- confirm against the callers.
    """

    def __init__(
        self,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Create an empty scheduler.

        Args:
            controllers: Pipeline-stage controllers. ``step`` hands each
                batch to ``controllers[0].execute_stage``.
            block_size: Block size, forwarded to ``BlockSpaceManager``.
            num_gpu_blocks: Number of GPU blocks, forwarded to
                ``BlockSpaceManager``.
            num_cpu_blocks: Number of CPU blocks, forwarded to
                ``BlockSpaceManager``.
        """
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Running sequence groups (FIFO).
        self.running: List[SequenceGroup] = []
        # Mapping: group_id -> number of post_step calls applied so far.
        # 0 means the group's next step processes its prompt.
        self.num_steps: Dict[int, int] = {}
        # Mapping: group_id -> max_num_steps.
        self.max_num_steps: Dict[int, int] = {}
        # Mapping: group_id -> stop_token_ids.
        self.stop_token_ids: Dict[int, List[int]] = {}

        # Swapped sequence groups (LIFO).
        self.swapped: List[SequenceGroup] = []
        # Pending sequence groups (FIFO).
        self.pending: List[SequenceGroup] = []

        # Tokenizer used only for the debug printout in post_step. It is
        # loaded on first use and reused afterwards, instead of being
        # reloaded for every finished group.
        self._tokenizer = None

    def _free_seq(self, seq: Sequence) -> None:
        """Mark ``seq`` as finished and release its blocks."""
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for a pending group and move it to ``running``."""
        self.block_manager.allocate(seq_group)
        for seq in seq_group.seqs:
            seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)
        # FIXME(woosuk): Support interactive generation.
        self.num_steps[seq_group.group_id] = 0

    def _append(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, int],
    ) -> None:
        """Reserve a new slot for every unfinished sequence in the group.

        When the block manager returns a ``(src_block, dst_block)`` pair, it
        is recorded in ``blocks_to_copy`` so the copy can be performed
        before the model executes.
        """
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
                continue
            ret = self.block_manager.append(seq)
            if ret is not None:
                src_block, dst_block = ret
                blocks_to_copy[src_block] = dst_block

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Swap a preempted group back in and append it to ``running``.

        The block mapping from the block manager is merged into
        ``blocks_to_swap_in``. Sequences in the SWAPPED state become
        RUNNING.
        """
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.SWAPPED:
                seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Preempt a running group by swapping its blocks out.

        The block mapping from the block manager is merged into
        ``blocks_to_swap_out``. Sequences in the RUNNING state become
        SWAPPED, and the group is pushed onto ``swapped``. Removing it from
        ``running`` is left to the caller.
        """
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.RUNNING:
                seq.status = SequenceStatus.SWAPPED
        self.swapped.append(seq_group)

    def step(self) -> None:
        """Schedule one step and submit it to the first pipeline stage.

        The step proceeds in five phases:

        1. Reserve slots for the running groups. When GPU blocks run short,
           swap out the most recently added groups.
        2. Swap preempted groups back in when possible. This is skipped if
           phase 1 swapped anything out.
        3. Admit pending groups, but only while nothing is swapped out.
        4. Build the model inputs.
        5. Pass the inputs to ``controllers[0].execute_stage``.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, int] = {}

        # 1. Reserve new slots for the running sequences.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # That is, the most recently added sequence group is the first
        # to be swapped out.
        victim_idx = len(self.running) - 1
        for i, seq_group in enumerate(self.running):
            if i > victim_idx:
                # The i-th sequence group has already been swapped out.
                break
            # OOM. Swap out the victim sequence groups.
            while not self.block_manager.can_append(seq_group):
                victim_seq_group = self.running[victim_idx]
                self._swap_out(victim_seq_group, blocks_to_swap_out)
                victim_idx -= 1
                if i > victim_idx:
                    # No other sequence groups can be swapped out; the
                    # i-th group itself was just swapped out.
                    break
            else:
                # There is enough space now, possibly after evictions.
                self._append(seq_group, blocks_to_copy)
        self.running = self.running[:victim_idx + 1]

        # 2. Swap in the swapped sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # The swapped sequences are in LIFO order.
        # Skip this phase when phase 1 swapped anything out. The GPU was just
        # short of blocks, and swap-in and swap-out must never happen in the
        # same step (see the assertion below). Without this guard, a swap-in
        # here would crash the scheduler.
        if not blocks_to_swap_out:
            for i, seq_group in enumerate(reversed(self.swapped)):
                if self.block_manager.can_swap_in(seq_group):
                    self._swap_in(seq_group, blocks_to_swap_in)
                    self._append(seq_group, blocks_to_copy)
                else:
                    # OOM. Stop swapping.
                    self.swapped = self.swapped[:len(self.swapped) - i]
                    break
            else:
                # All swapped sequences are swapped in.
                self.swapped.clear()

        # Each running sequence contributes one generation token.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # 3. Join new sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # TODO(woosuk): Add a batching policy to control the batch size.
        # NOTE(review): a prompt longer than _MAX_NUM_BATCHED_TOKENS can never
        # be admitted, so it blocks every group queued behind it.
        if not self.swapped:
            # FIXME(woosuk): Acquire a lock to protect pending.
            for i, seq_group in enumerate(self.pending):
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if self.block_manager.can_allocate(seq_group):
                    if (num_batched_tokens + num_prompt_tokens
                        <= _MAX_NUM_BATCHED_TOKENS):
                        self._allocate(seq_group)
                        num_batched_tokens += num_prompt_tokens
                        continue

                # Keep this group and everything after it pending.
                self.pending = self.pending[i:]
                break
            else:
                self.pending.clear()

        # Ensure that swap-in and swap-out never happen at the same timestep.
        if blocks_to_swap_in:
            assert not blocks_to_swap_out

        # 4. Create input data structures.
        prompt_tokens: Dict[int, List[int]] = {}
        generation_tokens: Dict[int, int] = {}
        context_lens: Dict[int, int] = {}
        block_tables: Dict[int, List[int]] = {}
        for seq_group in self.running:
            group_id = seq_group.group_id
            num_steps = self.num_steps[group_id]
            # NOTE(woosuk): We assume that the number of steps is 0
            # for the prompt sequences.
            is_prompt = num_steps == 0
            for seq in seq_group.seqs:
                if seq.status != SequenceStatus.RUNNING:
                    continue

                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    prompt_tokens[seq_id] = seq.get_token_ids()
                else:
                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
                    context_lens[seq_id] = seq.get_len()

        # 5. Execute the first stage of the pipeline.
        self.controllers[0].execute_stage(
            prompt_tokens,
            generation_tokens,
            context_lens,
            block_tables,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )

    def _fork_from_parent(
        self,
        seq_group: SequenceGroup,
        seq: Sequence,
        parent_seq_id: int,
    ) -> None:
        """Replace ``seq`` with a fork of its beam-search parent.

        This releases ``seq``'s current blocks, copies the parent's logical
        token blocks, and asks the block manager to fork the parent's blocks
        into ``seq``.
        """
        self.block_manager.free(seq)
        parent_seq = seq_group.find(parent_seq_id)
        # A deep copy is required. A shallow ``list.copy()`` leaves both
        # sequences holding the same block objects, so an in-place update
        # to one (e.g. adding tokens to its last block) would also appear
        # in the other.
        seq.logical_token_blocks = copy.deepcopy(
            parent_seq.logical_token_blocks)
        self.block_manager.fork(parent_seq, seq)

    def _get_tokenizer(self):
        """Return the debug-output tokenizer, loading it on first use."""
        if self._tokenizer is None:
            from transformers import AutoTokenizer
            self._tokenizer = AutoTokenizer.from_pretrained(
                'facebook/opt-125m')
        return self._tokenizer

    def post_step(
        self,
        next_tokens: Dict[int, Tuple[int, int]],
    ) -> None:
        """Apply the tokens sampled in the last step and retire finished groups.

        Args:
            next_tokens: Maps the ``seq_id`` of every unfinished sequence in
                every running group to ``(parent_seq_id, next_token)``. When
                ``parent_seq_id`` differs from ``seq_id`` (beam search), the
                sequence is first replaced by a fork of that parent.
        """
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.stop_token_ids[group_id]

            # Pass 1: apply every beam-search fork before appending any new
            # token. Otherwise a parent updated earlier in the loop would
            # hand its new token to the child, which would then append its
            # own token as well. It would also let a child fork from a
            # parent already freed by a stop condition this step.
            # NOTE(review): forks are applied in seq_group.seqs order. If a
            # sequence is re-forked AND is another sequence's parent in the
            # same step, the result depends on that order -- confirm the
            # sampler never produces that pattern.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue
                parent_seq_id, _ = next_tokens[seq.seq_id]
                if seq.seq_id != parent_seq_id:
                    self._fork_from_parent(seq_group, seq, parent_seq_id)

            # Pass 2: append the new tokens and check stop conditions.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                _, next_token = next_tokens[seq.seq_id]
                seq.append([next_token])

                # Check if the sequence has generated a stop token.
                if next_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of steps.
                if self.num_steps[group_id] == self.max_num_steps[group_id]:
                    self._free_seq(seq)
                    continue

        # Update the running sequences: retire groups whose sequences have
        # all finished.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if all(seq.status == SequenceStatus.FINISHED
                   for seq in seq_group.seqs):
                group_id = seq_group.group_id
                del self.num_steps[group_id]
                del self.max_num_steps[group_id]
                del self.stop_token_ids[group_id]
                # TODO: Return the seq_group to the client.
                tokenizer = self._get_tokenizer()
                for seq in seq_group.seqs:
                    token_ids = seq.get_token_ids()
                    output = tokenizer.decode(token_ids, skip_special_tokens=True)
                    print(f'Seq {seq.seq_id}: {output}')
            else:
                running.append(seq_group)
        self.running = running