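"""Scheduler for cacheflow.

Decides, on every step, which sequence groups run, which wait, and which
are preempted (by swapping their KV-cache blocks to CPU memory, or by
dropping them for later recomputation).
"""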
import enum
import time
from typing import Any, Dict, List, Optional, Tuple

from cacheflow.core.block_manager import BlockSpaceManager
from cacheflow.logger import init_logger
from cacheflow.core.policy import PolicyFactory
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupMetadata
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus


logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 10


class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


class Scheduler:
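    """Schedules sequence groups over a fixed pool of KV-cache blocks.

    The scheduler keeps three queues: self.waiting, self.running, and
    self.swapped. Each step() call reserves block space, resolves
    preemptions, and launches the first pipeline stage on the controllers.

    Illustrative driver loop (a sketch; everything outside this class,
    such as how `controllers` and `seq_outputs` are produced, is an
    assumption):

        scheduler = Scheduler(controllers, block_size, num_gpu_blocks,
                              num_cpu_blocks, max_num_batched_tokens,
                              max_num_sequences, log_stats)
        scheduler.add_sequence_groups(groups_and_sampling_params)
        while scheduler.waiting or scheduler.running or scheduler.swapped:
            scheduler.step()                  # schedule + run one iteration
            scheduler.post_step(seq_outputs)  # apply the model outputs
    """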

    def __init__(
        self,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        log_stats: bool,
    ) -> None:
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_sequences = max_num_sequences
        self.log_stats = log_stats

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Sequence groups in the WAITING state.
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state.
        self.running: List[SequenceGroup] = []
        # Mapping: group_id -> num_steps.
        self.num_steps: Dict[int, int] = {}
        # Mapping: group_id -> sampling params.
        self.sampling_params: Dict[int, SamplingParams] = {}
        # Sequence groups in the SWAPPED state.
        self.swapped: List[SequenceGroup] = []

        self.last_logging_time: float = 0.0
        # List of (timestamp, num_tokens) tuples.
        self.num_input_tokens: List[Tuple[float, int]] = []

    def add_sequence_groups(
        self,
        seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
    ) -> None:
        # Add sequence groups to the waiting queue.
        for seq_group, sampling_params in seq_groups:
            self.waiting.append(seq_group)
            self.sampling_params[seq_group.group_id] = sampling_params

    def _schedule(
        self,
    ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
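        # Returns (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
        # prompt_group_ids): block mappings for the workers to execute
        # before the model runs, plus the ids of the groups scheduled for
        # their prompt (prefill) step.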
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.time()

        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
        # in order to minimize the preemption overheads.
        # Preemption happens only when there is no available slot to keep all
        # the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        # Swap in sequence groups in priority (FCFS) order.
        while self.swapped and not blocks_to_swap_out:
            seq_group = self.swapped[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be swapped in, stop.
            if not self.block_manager.can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
            if len(self.running) + num_seqs > self.max_num_sequences:
                break

            seq_group = self.swapped.pop(0)
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slot(seq_group, blocks_to_copy)
            self.running.append(seq_group)

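        # In a generation step, each running sequence contributes one token;
        # prompt tokens for newly admitted groups are added below.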
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # Join waiting sequences if possible.
        prompt_group_ids: List[int] = []
        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
        # prioritized over the sequence groups in the WAITING state.
        # This is because we want to bound the amount of CPU memory taken by
        # the swapped sequence groups.
        if not self.swapped:
            self.waiting = self.policy.sort_by_priority(now, self.waiting)
            while self.waiting:
                seq_group = self.waiting[0]
                # If the sequence group has been preempted in this step, stop.
                if seq_group in preempted:
                    break
                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if (num_batched_tokens + num_prompt_tokens
                    > self.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
                if len(self.running) + num_seqs > self.max_num_sequences:
                    break

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                prompt_group_ids.append(seq_group.group_id)

        if not self.log_stats:
            return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
                    prompt_group_ids)

        now = time.time()
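        # Track throughput over a sliding window of the last
        # _LOGGING_INTERVAL_SEC seconds.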
        if num_batched_tokens > 0:
            self.num_input_tokens.append((now, num_batched_tokens))
        elapsed_time = now - self.last_logging_time
        if elapsed_time > _LOGGING_INTERVAL_SEC:
            self.last_logging_time = now
            self.num_input_tokens = [
                (t, n) for t, n in self.num_input_tokens
                if now - t < _LOGGING_INTERVAL_SEC
            ]
            if len(self.num_input_tokens) > 1:
                total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
                window = now - self.num_input_tokens[0][0]
                avg_throughput = total_num_tokens / window
            else:
                avg_throughput = 0.0

            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
            num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
            gpu_cache_usage = num_used_gpu_blocks / self.num_gpu_blocks
            if self.num_cpu_blocks > 0:
                num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
                num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
                cpu_cache_usage = num_used_cpu_blocks / self.num_cpu_blocks
            else:
                cpu_cache_usage = 0.0

            logger.info(
                f"Throughput: {avg_throughput:.1f} tokens/s, "
                f"Running: {len(self.running)} reqs, "
                f"Swapped: {len(self.swapped)} reqs, "
                f"Pending: {len(self.waiting)} reqs, "
                f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
                f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")

        return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
                prompt_group_ids)

    def step(self) -> List[SequenceGroup]:
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        (blocks_to_swap_in,
         blocks_to_swap_out,
         blocks_to_copy,
         prompt_group_ids) = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
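        # Snapshot the groups scheduled for this step to return to the caller.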
        updated_seq_groups: List[SequenceGroup] = self.running.copy()

        for seq_group in self.running:
            group_id = seq_group.group_id
            is_prompt = group_id in prompt_group_ids

            input_tokens: Dict[int, List[int]] = {}
            seq_logprobs: Dict[int, float] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    input_tokens[seq_id] = seq.get_token_ids()
                else:
                    input_tokens[seq_id] = [seq.get_last_token_id()]
                seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length.
                seq_len = seq.get_len()

            seq_group_metadata = SequenceGroupMetadata(
                group_id=group_id,
                is_prompt=is_prompt,
                input_tokens=input_tokens,
                context_len=seq_len,
                seq_logprobs=seq_logprobs,
                sampling_params=self.sampling_params[group_id],
                block_tables=block_tables,
            )
            seq_group_metadata_list.append(seq_group_metadata)

        # Execute the first stage of the pipeline.
        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
            # Swap in and swap out should never happen at the same time.
            assert not (blocks_to_swap_in and blocks_to_swap_out)
            self.controllers[0].execute_stage(
                seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

        return updated_seq_groups

    def post_step(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> None:
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.sampling_params[group_id].stop_token_ids

            # Process beam search results before processing the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append_token(output.output_token, output.logprobs)

                # Check if the sequence has generated a stop token.
                if output.output_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of steps.
                max_num_steps = self.sampling_params[group_id].max_num_steps
                if self.num_steps[group_id] == max_num_steps:
                    self._free_seq(seq)
                    continue

        # Update the running sequences.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if seq_group.is_finished():
                self._free_seq_group(seq_group)
            else:
                running.append(seq_group)
        self.running = running

    def _allocate(self, seq_group: SequenceGroup) -> None:
        self.block_manager.allocate(seq_group)
        for seq in seq_group.seqs:
            seq.status = SequenceStatus.RUNNING
        # FIXME(woosuk): Support interactive generation.
        if seq_group.group_id not in self.num_steps:
            self.num_steps[seq_group.group_id] = 0

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
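            # append_slot returns a (src_block, dst_block) pair when the
            # sequence's last block is shared and must be copied before it
            # can be written to; record the copy for the workers.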
            if ret is not None:
                src_block, dst_block = ret
                if src_block in blocks_to_copy:
                    blocks_to_copy[src_block].append(dst_block)
                else:
                    blocks_to_copy[src_block] = [dst_block]

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not supported. In such a case,
        # we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            assert False, 'Invalid preemption mode.'

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
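        # Recomputation: free all of the group's blocks and put it back on
        # the waiting queue so it is treated as a new prompt.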
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        self.waiting.append(seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        for seq in seqs:
            seq.status = SequenceStatus.SWAPPED
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _free_seq(self, seq: Sequence) -> None:
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
        group_id = seq_group.group_id
        del self.num_steps[group_id]
        del self.sampling_params[group_id]

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
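        # Bring the group's blocks back from CPU to GPU and mark its
        # sequences RUNNING.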
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
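        # Move the group's blocks from GPU to CPU and mark its sequences
        # SWAPPED.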
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED