"README_origin.md" did not exist on "196119b4cf6e9ba8a04385b5c137ac3021fa430d"
scheduler.py 15.2 KB
Newer Older
1
2
3
import enum
import time
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

from cacheflow.master.block_manager import BlockSpaceManager
6
from cacheflow.master.policy import PolicyFactory
Woosuk Kwon's avatar
Woosuk Kwon committed
7
from cacheflow.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
8
9
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
10
11
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
14
from cacheflow.sequence import SequenceStatus


15
16
17
18
19
20
21
22
23
24
25
26
27
class PreemptionMode(enum.Enum):
    """How a preempted sequence group gives up its cache blocks.

    SWAP moves the group's blocks out to CPU memory and swaps them back
    in when the group is resumed.  RECOMPUTE simply frees the blocks and
    later re-runs the group from scratch, treating it as a brand-new
    prompt.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


Woosuk Kwon's avatar
Woosuk Kwon committed
28
29
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
30
    def __init__(
        self,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        max_num_batched_tokens: int,
    ) -> None:
        """Initialize the scheduler and its bookkeeping structures.

        Args:
            controllers: Pipeline-stage controllers; step() only invokes
                controllers[0] — presumably later stages are chained
                elsewhere, TODO confirm.
            block_size: Number of tokens held by one cache block.
            num_gpu_blocks: Total number of GPU cache blocks available.
            num_cpu_blocks: Total number of CPU (swap-space) cache blocks.
            max_num_batched_tokens: Upper bound on the number of tokens
                scheduled into a single step.
        """
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.max_num_batched_tokens = max_num_batched_tokens

        # Instantiate the scheduling policy (first-come-first-served).
        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
        # Create the block space manager that tracks GPU/CPU cache blocks.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Sequence groups in the WAITING state (not yet allocated).
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state (scheduled this step).
        self.running: List[SequenceGroup] = []
        # Mapping: group_id -> number of decoding steps taken so far.
        self.num_steps: Dict[int, int] = {}
        # Mapping: group_id -> sampling params supplied with the request.
        self.sampling_params: Dict[int, SamplingParams] = {}
        # Sequence groups in the SWAPPED state (blocks live in CPU memory).
        self.swapped: List[SequenceGroup] = []
Woosuk Kwon's avatar
Woosuk Kwon committed
63

64
65
    def add_sequence_groups(
        self,
        seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
    ) -> None:
        """Enqueue new sequence groups into the WAITING queue."""
        for group, params in seq_groups:
            # Remember the group's sampling parameters, then park the group
            # itself until _schedule() admits it into the running set.
            self.sampling_params[group.group_id] = params
            self.waiting.append(group)

73
    def _schedule(
        self,
    ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
        """Pick the sequence groups to run in the next step.

        Mutates self.running, self.swapped, and self.waiting in place.

        Returns:
            A tuple of
            (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
             prompt_group_ids), where the first three are block-number
            mappings for the cache engine and the last lists the group ids
            newly admitted from the WAITING queue (i.e. prompt runs).
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time so that all priority comparisons in this
        # scheduling round use the same timestamp.
        now = time.time()

        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
        # in order to minimize the preemption overheads.
        # Preemption happens only when there is no available slot to keep all
        # the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            # Keep preempting until the current group's next token fits.
            while not self.block_manager.can_append(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        while self.swapped:
            seq_group = self.swapped[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be swapped in, stop.
            if not self.block_manager.can_swap_in(seq_group):
                break

            seq_group = self.swapped.pop(0)
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append(seq_group, blocks_to_copy)
            self.running.append(seq_group)

        # One token per RUNNING sequence will be generated this step.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # Join waiting sequences if possible.
        prompt_group_ids: List[int] = []
        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
        # prioritized over the sequence groups in the WAITING state.
        # This is because we want to bound the amount of CPU memory taken by
        # the swapped sequence groups.
        if not self.swapped:
            self.waiting = self.policy.sort_by_priority(now, self.waiting)
            while self.waiting:
                seq_group = self.waiting[0]
                # If the sequence group has been preempted in this step, stop.
                if seq_group in preempted:
                    break
                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                # Prompts cost their full length, unlike decode steps.
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if (num_batched_tokens + num_prompt_tokens
                    > self.max_num_batched_tokens):
                    break

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                prompt_group_ids.append(seq_group.group_id)

        return (blocks_to_swap_in,
                blocks_to_swap_out,
                blocks_to_copy,
                prompt_group_ids)

    def step(self) -> List[SequenceGroup]:
        """Run one scheduling iteration and kick off model execution.

        Returns:
            The list of sequence groups scheduled for this step (a snapshot
            of self.running taken right after scheduling).
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_output = self._schedule()
        blocks_to_swap_in = scheduler_output[0]
        blocks_to_swap_out = scheduler_output[1]
        blocks_to_copy = scheduler_output[2]
        prompt_group_ids = scheduler_output[3]

        # Create input data structures.
        input_seq_groups: List[SequenceGroupInputs] = []
        # Snapshot before execution: post_step() later mutates self.running.
        updated_seq_groups: List[SequenceGroup] = self.running.copy()

        for seq_group in self.running:
            group_id = seq_group.group_id
            # Groups admitted this step run their full prompt; the rest
            # decode a single token.
            is_prompt = group_id in prompt_group_ids

            input_tokens: Dict[int, List[int]] = {}
            seq_logprobs: Dict[int, float] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    input_tokens[seq_id] = seq.get_token_ids()
                else:
                    input_tokens[seq_id] = [seq.get_last_token_id()]
                seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length
                # NOTE(review): seq_len is only bound inside this loop; if a
                # running group ever had zero RUNNING seqs this would raise
                # NameError (or reuse the previous group's value) — confirm
                # that cannot happen.
                seq_len = seq.get_len()

            input_seq_group = SequenceGroupInputs(
                group_id=group_id,
                is_prompt=is_prompt,
                input_tokens=input_tokens,
                context_len=seq_len,
                seq_logprobs=seq_logprobs,
                sampling_params=self.sampling_params[group_id],
                block_tables=block_tables,
            )
            input_seq_groups.append(input_seq_group)

        # Execute the first stage of the pipeline.
        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
            # Swap in and swap out should never happen at the same time.
            assert not (blocks_to_swap_in and blocks_to_swap_out)
            self.controllers[0].execute_stage(
                input_seq_groups,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

        return updated_seq_groups

Woosuk Kwon's avatar
Woosuk Kwon committed
227
228
    def post_step(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> None:
        """Apply the model outputs of the last step to the scheduler state.

        Args:
            seq_outputs: Mapping seq_id -> output (sampled token, logprobs,
                and parent_seq_id for beam-search forks) for every running
                sequence.
        """
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.sampling_params[group_id].stop_token_ids

            # Process beam search results before processing the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence: copy its state and share its
                    # cache blocks with the child.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append(output.output_token, output.logprobs)

                # Check if the sequence has generated a stop token.
                if output.output_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of steps.
                # NOTE(review): this uses == rather than >=; it relies on the
                # step counter never skipping max_num_steps — confirm.
                max_num_steps = self.sampling_params[group_id].max_num_steps
                if self.num_steps[group_id] == max_num_steps:
                    self._free_seq(seq)
                    continue

        # Update the running sequences: retire fully-finished groups.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if seq_group.is_finished():
                self._free_seq_group(seq_group)
            else:
                running.append(seq_group)
        self.running = running
Woosuk Kwon's avatar
Woosuk Kwon committed
280

281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Give the group its initial cache blocks and mark all seqs RUNNING."""
        self.block_manager.allocate(seq_group)
        for seq in seq_group.seqs:
            seq.status = SequenceStatus.RUNNING
        # FIXME(woosuk): Support interactive generation.
        # Initialize the step counter only on first admission.
        self.num_steps.setdefault(seq_group.group_id, 0)

    def _append(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve a slot for the next token of every RUNNING sequence.

        Args:
            seq_group: The group whose running sequences need one new slot.
            blocks_to_copy: Out-parameter mapping src block number -> list of
                dst block numbers. When the block manager returns a
                (src, dst) pair — presumably a block that must be duplicated
                before writing, TODO confirm — it is recorded here for the
                cache engine.
        """
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append(seq)
            if ret is not None:
                src_block, dst_block = ret
                # setdefault replaces the manual "in dict" branch: it creates
                # the list on first use and appends in either case.
                blocks_to_copy.setdefault(src_block, []).append(dst_block)

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        """Preempt a sequence group, freeing its slots for higher priorities.

        Args:
            seq_group: The group to preempt.
            blocks_to_swap_out: Out-parameter collecting GPU->CPU block
                mappings when swapping is used.
            preemption_mode: Force a specific mode; when None, the mode is
                chosen automatically (see below).

        Raises:
            ValueError: If an unknown preemption mode is passed.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not supported. In such a case,
        # we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            # Raise instead of `assert False`: asserts are stripped under
            # `python -O`, which would silently skip the preemption.
            raise ValueError('Invalid preemption mode.')

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Drop the group's blocks and requeue it as if it were a new prompt."""
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Recomputation only works for single-sequence groups.
        assert len(seqs) == 1
        seq = seqs[0]
        seq.status = SequenceStatus.WAITING
        self.block_manager.free(seq)
        self.waiting.append(seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Move the group's blocks to CPU memory and park it in `swapped`."""
        # Mark the sequences SWAPPED first, then record the block mapping.
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _free_seq(self, seq: Sequence) -> None:
        """Mark a single sequence FINISHED and release its cache blocks."""
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

359
    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
        """Drop the per-group bookkeeping once every sequence has finished."""
        gid = seq_group.group_id
        # pop() without a default raises KeyError just like `del` would,
        # surfacing a double-free of the same group.
        self.num_steps.pop(gid)
        self.sampling_params.pop(gid)
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Bring the group's blocks back onto the GPU and resume its seqs."""
        blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Evict the group's GPU blocks to CPU swap space."""
        # Caller must guarantee there is CPU room; this is an invariant check.
        assert self.block_manager.can_swap_out(seq_group)
        blocks_to_swap_out.update(self.block_manager.swap_out(seq_group))
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED