scheduler.py 10.6 KB
Newer Older
1
from typing import Dict, List
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3

from cacheflow.master.block_manager import BlockSpaceManager
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
8
9
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
12
13
14
from cacheflow.sequence import SequenceStatus


class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
15
    def __init__(
        self,
        frontend: Frontend,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        max_num_batched_tokens: int,
    ) -> None:
        """Set up scheduler state and the block space manager.

        The scheduler owns three queues of sequence groups (pending,
        running, swapped) plus per-group bookkeeping dictionaries.
        """
        self.frontend = frontend
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.max_num_batched_tokens = max_num_batched_tokens

        # Tracks allocation of GPU/CPU cache blocks for sequences.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Sequence groups currently being executed (FIFO order).
        self.running: List[SequenceGroup] = []
        # Per-group step counter, keyed by group_id.
        self.num_steps: Dict[int, int] = {}
        # Per-group sampling parameters, keyed by group_id.
        self.sampling_params: Dict[int, SamplingParams] = {}
        # Sequence groups moved out to CPU memory (LIFO order).
        self.swapped: List[SequenceGroup] = []
        # Newly arrived sequence groups awaiting allocation (FIFO order).
        self.pending: List[SequenceGroup] = []
    def _fetch_inputs(self) -> None:
        inputs = self.frontend.get_inputs()
        for seq_group, sampling_params in inputs:
            self.pending.append(seq_group)
            self.sampling_params[seq_group.group_id] = sampling_params

Woosuk Kwon's avatar
Woosuk Kwon committed
56
57
58
59
60
61
62
63
    def _free_seq(self, seq: Sequence) -> None:
        """Mark *seq* as finished and release its cache blocks.

        The status is set before freeing; keep this order — the block
        manager's view of the sequence is released as the final action.
        """
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)
    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate cache blocks for a pending group and start running it.

        Every sequence in the group becomes RUNNING and the group's step
        counter is initialized to zero.
        """
        self.block_manager.allocate(seq_group)
        for sequence in seq_group.seqs:
            sequence.status = SequenceStatus.RUNNING
        self.running.append(seq_group)
        # FIXME(woosuk): Support interactive generation.
        self.num_steps[seq_group.group_id] = 0
    def _append(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve a slot for the next token of each unfinished sequence.

        When the block manager reports that appending requires copying a
        block, the (source -> destinations) pair is accumulated into
        *blocks_to_copy* for the caller to execute.
        """
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
                continue
            copy_pair = self.block_manager.append(seq)
            if copy_pair is None:
                continue
            src_block, dst_block = copy_pair
            blocks_to_copy.setdefault(src_block, []).append(dst_block)
    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Bring a swapped-out group back and resume its sequences.

        The CPU-to-GPU block mapping produced by the block manager is
        merged into *blocks_to_swap_in* for the caller to execute.
        """
        blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)
    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Preempt a running group by moving its blocks to CPU memory.

        The caller must have ensured the swap-out can succeed; the
        GPU-to-CPU block mapping is merged into *blocks_to_swap_out*.
        """
        assert self.block_manager.can_swap_out(seq_group)
        blocks_to_swap_out.update(self.block_manager.swap_out(seq_group))
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED
        self.swapped.append(seq_group)
    def step(self) -> None:
        """Run one scheduling iteration and launch the first pipeline stage.

        Phase order is load-bearing: (1) appending slots for running groups
        may force swap-outs; (2) swap-ins are attempted only afterwards;
        (3) new pending groups are admitted only when nothing remains
        swapped out; (4) model inputs are built from the final running set.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # 1. Reserve new slots for the running sequences.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # That is, the most recently added sequence group is the first
        # to be swapped out.
        victim_idx = len(self.running) - 1
        for i, seq_group in enumerate(self.running):
            if i > victim_idx:
                # The i-th sequence group has already been swapped out.
                break
            # OOM. Swap out the victim sequence groups.
            while not self.block_manager.can_append(seq_group):
                # Victims are taken from the back of the queue; note the
                # victim may be seq_group itself when victim_idx == i.
                victim_seq_group = self.running[victim_idx]
                self._swap_out(victim_seq_group, blocks_to_swap_out)
                victim_idx -= 1
                if i > victim_idx:
                    # No other sequence groups can be swapped out.
                    break
            else:
                # while-else: runs only if no break, i.e. space was found.
                self._append(seq_group, blocks_to_copy)
        # Groups past victim_idx were swapped out above; drop them here.
        self.running = self.running[:victim_idx + 1]

        # 2. Swap in the swapped sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # The swapped sequences are in LIFO order.
        for i, seq_group in enumerate(reversed(self.swapped)):
            if self.block_manager.can_swap_in(seq_group):
                self._swap_in(seq_group, blocks_to_swap_in)
                self._append(seq_group, blocks_to_copy)
            else:
                # OOM. Stop swapping.
                # The last i entries (already swapped in) are removed.
                self.swapped = self.swapped[:len(self.swapped) - i]
                break
        else:
            # All swapped sequences are swapped in.
            self.swapped.clear()

        # Ensure that swap-in and swap-out never happen at the same timestep.
        if blocks_to_swap_in:
            assert not blocks_to_swap_out

        # Token budget: each already-running sequence is counted as needing
        # one slot this step (prompt admission below adds its full length).
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # 3. Join new sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # TODO(woosuk): Add a batching policy to control the batch size.
        if not self.swapped:
            self._fetch_inputs()
            for i, seq_group in enumerate(self.pending):
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if self.block_manager.can_allocate(seq_group):
                    if (num_batched_tokens + num_prompt_tokens
                        <= self.max_num_batched_tokens):
                        self._allocate(seq_group)
                        num_batched_tokens += num_prompt_tokens
                        continue

                # First group that does not fit stops admission; it and all
                # later arrivals stay pending (FCFS).
                self.pending = self.pending[i:]
                break
            else:
                self.pending.clear()

        # 4. Create input data structures.
        input_seq_groups: List[SequenceGroupInputs] = []
        for seq_group in self.running:
            group_id = seq_group.group_id
            num_steps = self.num_steps[group_id]
            # NOTE(woosuk): We assume that the number of steps is 0
            # for the prompt sequences.
            is_prompt = num_steps == 0

            input_tokens: Dict[int, List[int]] = {}
            seq_logprobs: Dict[int, float] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    # Prompt phase: feed the whole prompt.
                    input_tokens[seq_id] = seq.get_token_ids()
                else:
                    # Generation phase: feed only the last generated token.
                    input_tokens[seq_id] = [seq.get_last_token_id()]
                seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length
                seq_len = seq.get_len()

            input_seq_group = SequenceGroupInputs(
                group_id=group_id,
                is_prompt=is_prompt,
                input_tokens=input_tokens,
                context_len=seq_len,
                seq_logprobs=seq_logprobs,
                sampling_params=self.sampling_params[group_id],
                block_tables=block_tables,
            )
            input_seq_groups.append(input_seq_group)

        # 5. Execute the first stage of the pipeline.
        self.controllers[0].execute_stage(
            input_seq_groups,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )
    def post_step(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> None:
        """Apply model outputs for one step.

        For each running group: materialize beam-search forks, append the
        new token to each live sequence, retire sequences that hit a stop
        token or the step limit, and return fully finished groups.

        Args:
            seq_outputs: Mapping from seq_id to that sequence's output
                (parent id, sampled token, logprobs).
        """
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.sampling_params[group_id].stop_token_ids

            # Process beam search results before processing the next tokens.
            # The fork must replace a sequence's state BEFORE its new token
            # is appended in the second loop below.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append(output.output_token, output.logprobs)

                # Check if the sequence has generated a stop token.
                if output.output_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of steps.
                # num_steps advances by exactly one per call, so equality is
                # sufficient here.
                max_num_steps = self.sampling_params[group_id].max_num_steps
                if self.num_steps[group_id] == max_num_steps:
                    self._free_seq(seq)
                    continue

        # Update the running sequences.
        # Finished groups are returned to the frontend; the rest stay
        # running in their original order.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if seq_group.is_finished():
                self._return(seq_group)
            else:
                running.append(seq_group)
        self.running = running
    def _return(self, seq_group: SequenceGroup) -> None:
        group_id = seq_group.group_id
        del self.num_steps[group_id]
        del self.sampling_params[group_id]
        self.frontend.print_response(seq_group)