# scheduler.py
from typing import Dict, List, Tuple

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceStatus

# Cap on the number of tokens batched into a single step. `Scheduler.pre_step`
# stops admitting pending groups once adding the next prompt would exceed it.
_MAX_NUM_BATCHED_TOKENS = 2048


class Scheduler:

    def __init__(
        self,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Build an idle scheduler together with its block space manager.

        Args:
            controllers: Pipeline stage controllers; this class only invokes
                the first one directly (in `step`).
            block_size: Passed through to `BlockSpaceManager`.
            num_gpu_blocks: Passed through to `BlockSpaceManager`.
            num_cpu_blocks: Passed through to `BlockSpaceManager`.
        """
        # Keep the construction parameters around for later inspection.
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        # The block space manager owns all block bookkeeping.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Queues of sequence groups, all initially empty:
        #   running - FIFO
        #   swapped - LIFO
        #   pending - FIFO
        self.running: List[SequenceGroup] = []
        self.swapped: List[SequenceGroup] = []
        self.pending: List[SequenceGroup] = []

        # Per-group bookkeeping, keyed by group_id.
        self.num_steps: Dict[int, int] = {}             # steps taken so far
        self.max_num_steps: Dict[int, int] = {}         # step budget
        self.stop_token_ids: Dict[int, List[int]] = {}  # tokens ending a sequence

    def _free_seq(self, seq: Sequence) -> None:
        """Mark `seq` as FINISHED and hand it to `block_manager.free`."""
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Admit `seq_group`: allocate its blocks and start it running.

        Every sequence of the group is marked RUNNING, the group's step
        counter is reset to 0, and the group joins the tail of `self.running`.
        """
        self.block_manager.allocate(seq_group)
        for member in seq_group.seqs:
            member.status = SequenceStatus.RUNNING

        # FIXME(woosuk): Support interactive generation.
        # A step count of 0 is what `step` uses to recognise a prompt.
        self.num_steps[seq_group.group_id] = 0
        self.running.append(seq_group)
    def _append(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, int],
    ) -> None:
        """Ask the block manager for a new slot for each unfinished sequence.

        Whenever `block_manager.append` returns a `(src, dst)` pair, it is
        recorded in `blocks_to_copy` as `src -> dst`. The dict is mutated in
        place.
        """
        # Lazy generator: sequences are filtered and appended one at a time.
        unfinished = (
            member for member in seq_group.seqs
            if member.status != SequenceStatus.FINISHED
        )
        for member in unfinished:
            copy_pair = self.block_manager.append(member)
            if copy_pair is None:
                # Nothing to copy for this sequence.
                continue
            source, destination = copy_pair
            blocks_to_copy[source] = destination
    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Swap `seq_group` back in and put it at the tail of `self.running`.

        The block mapping returned by `block_manager.swap_in` is merged into
        `blocks_to_swap_in` (mutated in place), and every SWAPPED sequence of
        the group becomes RUNNING again.
        """
        blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))

        # Only sequences that were actually swapped change state.
        swapped_members = [
            member for member in seq_group.seqs
            if member.status == SequenceStatus.SWAPPED
        ]
        for member in swapped_members:
            member.status = SequenceStatus.RUNNING

        self.running.append(seq_group)
    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap `seq_group` out and push it onto `self.swapped`.

        The block mapping returned by `block_manager.swap_out` is merged into
        `blocks_to_swap_out` (mutated in place), and every RUNNING sequence of
        the group becomes SWAPPED. The group is NOT removed from
        `self.running` here; that is left to the caller.
        """
        # Internal invariant: the block manager must accept the swap-out.
        assert self.block_manager.can_swap_out(seq_group)
        blocks_to_swap_out.update(self.block_manager.swap_out(seq_group))

        # Only sequences that were running change state.
        running_members = [
            member for member in seq_group.seqs
            if member.status == SequenceStatus.RUNNING
        ]
        for member in running_members:
            member.status = SequenceStatus.SWAPPED

        self.swapped.append(seq_group)

    def pre_step(self) -> None:
        """Decide what runs in the next iteration, then launch it via `step`.

        Scheduling happens in three phases:
          1. Reserve a slot for every running group, swapping out groups from
             the tail of `self.running` for as long as the block manager
             reports that an append is not possible.
          2. Swap groups back in, most recently swapped out first -- but only
             if phase 1 did not swap anything out.
          3. If no group remains swapped, admit pending groups in order while
             the block manager can allocate them and the batch stays within
             `_MAX_NUM_BATCHED_TOKENS`.

        The collected swap-in / swap-out / copy mappings are passed to `step`.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, int] = {}

        # 1. Prepare new slots for the running sequences.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # That is, the most recently added sequence group is the first
        # to be swapped out.
        victim_idx = len(self.running) - 1
        for i, seq_group in enumerate(self.running):
            if i > victim_idx:
                # The i-th sequence group has already been swapped out.
                break
            # OOM. Swap out the victim sequence groups.
            while not self.block_manager.can_append(seq_group):
                victim_seq_group = self.running[victim_idx]
                self._swap_out(victim_seq_group, blocks_to_swap_out)
                victim_idx -= 1
                if i > victim_idx:
                    # No other sequence groups can be swapped out
                    # (the current group itself was the last victim).
                    break
            else:
                # Reached only when the `while` ended without `break`,
                # i.e. the block manager can now append for this group.
                self._append(seq_group, blocks_to_copy)
        # Everything past victim_idx was swapped out above.
        self.running = self.running[:victim_idx + 1]

        # 2. Swap in the swapped sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # The swapped sequences are in LIFO order.
        #
        # `step` asserts that one iteration never carries both swap-in and
        # swap-out mappings. If phase 1 had to swap something out, skip this
        # phase entirely; the swapped groups are retried on a later call.
        if not blocks_to_swap_out:
            for i, seq_group in enumerate(reversed(self.swapped)):
                if self.block_manager.can_swap_in(seq_group):
                    self._swap_in(seq_group, blocks_to_swap_in)
                    self._append(seq_group, blocks_to_copy)
                else:
                    # OOM. Stop swapping. The last `i` entries were swapped
                    # in; keep only the ones that were not.
                    self.swapped = self.swapped[:len(self.swapped) - i]
                    break
            else:
                # All swapped sequences are swapped in.
                self.swapped.clear()

        # Each running sequence counts as one token towards the batch here.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # 3. Join new sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # TODO(woosuk): Add a batching policy to control the batch size.
        if not self.swapped:
            # FIXME(woosuk): Acquire a lock to protect pending.
            for i, seq_group in enumerate(self.pending):
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if self.block_manager.can_allocate(seq_group):
                    if (num_batched_tokens + num_prompt_tokens
                        <= _MAX_NUM_BATCHED_TOKENS):
                        self._allocate(seq_group)
                        num_batched_tokens += num_prompt_tokens
                        continue

                # First group that does not fit: keep it and all later ones.
                self.pending = self.pending[i:]
                break
            else:
                # Every pending group was admitted.
                self.pending.clear()

        # Execute step.
        self.step(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

    def step(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, int],
    ) -> None:
        """Assemble the model inputs for all running sequences and run stage 0.

        For every RUNNING sequence of every running group, its block table is
        collected. A group whose step count is 0 contributes its full token
        list as a prompt; any other group contributes only its last token
        plus the sequence length as context length. Everything is then handed
        to `self.controllers[0].execute_stage` along with the three block
        mappings.

        Raises:
            AssertionError: if both `blocks_to_swap_in` and
                `blocks_to_swap_out` are non-empty.
        """
        # Swap-in and swap-out must never happen within the same timestep.
        assert not (blocks_to_swap_in and blocks_to_swap_out)

        # Input data structures, all keyed by seq_id.
        prompt_tokens: Dict[int, List[int]] = {}
        generation_tokens: Dict[int, int] = {}
        context_lens: Dict[int, int] = {}
        block_tables: Dict[int, List[int]] = {}

        for group in self.running:
            # NOTE(woosuk): We assume that the number of steps is 0
            # for the prompt sequences.
            in_prompt_phase = self.num_steps[group.group_id] == 0

            for member in group.seqs:
                if member.status == SequenceStatus.RUNNING:
                    sid = member.seq_id
                    block_tables[sid] = self.block_manager.get_block_table(member)
                    if in_prompt_phase:
                        prompt_tokens[sid] = member.get_token_ids()
                        continue
                    generation_tokens[sid] = member.get_token_ids()[-1]
                    context_lens[sid] = member.get_len()

        # Execute the first stage of the pipeline.
        first_stage = self.controllers[0]
        first_stage.execute_stage(
            prompt_tokens,
            generation_tokens,
            context_lens,
            block_tables,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )
    def post_step(
        self,
        next_tokens: Dict[int, Tuple[int, int]],
    ) -> None:
        """Apply one step's sampled tokens and retire finished groups.

        Args:
            next_tokens: Maps each unfinished running sequence's seq_id to a
                `(parent_seq_id, next_token)` pair. When `parent_seq_id`
                differs from the sequence's own id, the sequence is replaced
                by a fork of that parent before the token is appended.

        A sequence is freed once it emits one of its group's stop tokens or
        its group reaches `max_num_steps`. Groups whose sequences are all
        FINISHED are dropped from `self.running`, their bookkeeping entries
        are deleted, and their decoded outputs are printed.
        """
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.stop_token_ids[group_id]

            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                parent_seq_id, next_token = next_tokens[seq.seq_id]
                if seq.seq_id != parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    # NOTE(review): if the parent comes earlier in
                    # `seq_group.seqs`, it has already received this step's
                    # token when it is copied here -- confirm this ordering
                    # is intended.
                    parent_seq = seq_group.find(parent_seq_id)
                    seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
                    self.block_manager.fork(parent_seq, seq)

                # Append a new token to the sequence.
                seq.append([next_token])

                # Finish the sequence if it generated a stop token or the
                # group reached its maximum number of steps. The step-limit
                # lookup is only evaluated when no stop token was produced.
                if (next_token in stop_token_ids
                        or self.num_steps[group_id] == self.max_num_steps[group_id]):
                    self._free_seq(seq)

        # Update the running sequences.
        running: List[SequenceGroup] = []
        # Loading the tokenizer is expensive and does not depend on the group,
        # so do it at most once per call, and only if some group finished.
        tokenizer = None
        for seq_group in self.running:
            if not all(seq.status == SequenceStatus.FINISHED
                       for seq in seq_group.seqs):
                running.append(seq_group)
                continue

            del self.num_steps[seq_group.group_id]
            del self.max_num_steps[seq_group.group_id]
            del self.stop_token_ids[seq_group.group_id]
            # TODO: Return the seq_group to the client.
            if tokenizer is None:
                from transformers import AutoTokenizer
                tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
            for seq in seq_group.seqs:
                token_ids = seq.get_token_ids()
                output = tokenizer.decode(token_ids, skip_special_tokens=True)
                print(f'Seq {seq.seq_id}: {output}')
        self.running = running