scheduler.py 18.4 KB
Newer Older
1
2
import enum
import time
3
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
4

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
8
9
10
11
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                           SequenceGroupMetadata, SequenceOutputs,
                           SequenceStatus)
Woosuk Kwon's avatar
Woosuk Kwon committed
12

Woosuk Kwon's avatar
Woosuk Kwon committed
13
# Logger used by everything in this module.
logger = init_logger(__name__)

# Length, in seconds, of the scheduler's periodic stats window: stats are
# logged at most once per interval, and throughput is averaged over the
# token counts recorded within the last interval.
_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17


18
19
20
21
22
23
24
25
26
27
28
29
30
class PreemptionMode(enum.Enum):
    """How a preempted sequence group gives up its GPU KV-cache blocks.

    SWAP:      the group's blocks are copied out to CPU memory, and copied
               back in when the group is resumed.
    RECOMPUTE: the group's blocks are discarded; on resumption the sequences
               are treated as fresh prompts and their cache is rebuilt.
    """
    # Values match what ``enum.auto()`` assigns (1, 2, ...).
    SWAP = 1
    RECOMPUTE = 2


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class SchedulerOutputs:
    """Cache-block operations the scheduler requests for one step.

    Attributes:
        blocks_to_swap_in: int -> int block mapping for blocks being swapped
            in (presumably source block -> destination block; confirm
            against the block manager).
        blocks_to_swap_out: int -> int block mapping for blocks being
            swapped out.
        blocks_to_copy: int -> list of int; each key block is to be copied
            into every block of its value list.
    """

    def __init__(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        self.blocks_to_swap_in = blocks_to_swap_in
        self.blocks_to_swap_out = blocks_to_swap_out
        self.blocks_to_copy = blocks_to_copy
        # Invariant: a single step never swaps in both directions at once.
        assert not blocks_to_swap_in or not blocks_to_swap_out

    def is_empty(self) -> bool:
        """Return True when there is nothing to swap or copy this step."""
        pending_ops = (
            self.blocks_to_swap_in,
            self.blocks_to_swap_out,
            self.blocks_to_copy,
        )
        return not any(pending_ops)


Woosuk Kwon's avatar
Woosuk Kwon committed
51
52
class Scheduler:
    """Decides which sequence groups the engine runs on each step.

    Every sequence group tracked by the scheduler sits in one of three
    queues:

    * ``waiting``: groups without allocated blocks. These are new requests,
      plus groups preempted by recomputation, which are re-inserted at the
      front.
    * ``running``: groups whose blocks are allocated through the block
      manager; ``schedule()`` builds model inputs from this queue.
    * ``swapped``: groups preempted by swapping.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        log_stats: bool,
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.log_stats = log_stats

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
        )

        # Sequence groups in the WAITING state.
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state.
        self.running: List[SequenceGroup] = []
        # Sequence groups in the SWAPPED state.
        self.swapped: List[SequenceGroup] = []

        # State for the periodic stats log (only used when log_stats=True).
        self.last_logging_time: float = 0.0
        # List[timestamp, num_tokens]
        self.num_input_tokens: List[Tuple[float, int]] = []

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        """Enqueue a new sequence group at the back of the waiting queue."""
        self.waiting.append(seq_group)

    def abort_seq_group(self, request_id: str) -> None:
        """Abort the sequence group whose ``request_id`` matches.

        The group is removed from whichever queue holds it, and each of its
        not-yet-finished sequences is freed with ``FINISHED_ABORTED``.
        Nothing happens if no queue contains a matching group.
        """
        for state_queue in [self.waiting, self.running, self.swapped]:
            for seq_group in state_queue:
                if seq_group.request_id == request_id:
                    # Removing during iteration is safe here because we
                    # return immediately afterwards.
                    state_queue.remove(seq_group)
                    for seq in seq_group.seqs:
                        if seq.is_finished():
                            continue
                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
                    return

    def has_unfinished_seqs(self) -> bool:
        """Return True if any queue still holds a sequence group."""
        # bool() so callers get the annotated type, not one of the lists.
        return bool(self.waiting or self.running or self.swapped)

    def get_num_unfinished_seq_groups(self) -> int:
        """Return the total number of queued sequence groups."""
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def _num_running_seqs(self) -> int:
        """Count RUNNING-status sequences across ``self.running``."""
        return sum(
            group.num_seqs(status=SequenceStatus.RUNNING)
            for group in self.running)

    def _schedule(
            self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
        """Update the scheduler queues for one step.

        This call mutates ``self.running``, ``self.swapped`` and
        ``self.waiting``.

        Returns:
            A tuple of:
            * the block operations to perform before model execution;
            * the request ids of the groups admitted from the waiting queue
              during this step;
            * the waiting groups marked ``FINISHED_IGNORED`` because their
              prompt length reached ``max_seq_len``.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}
        ignored_seq_groups: List[SequenceGroup] = []

        # Fix the current time.
        now = time.time()

        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
        # in order to minimize the preemption overheads.
        # Preemption happens only when there is no available slot to keep all
        # the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence group.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        # Nothing is swapped in on a step that already swapped something out,
        # because SchedulerOutputs forbids both directions in one step.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        while self.swapped and not blocks_to_swap_out:
            seq_group = self.swapped[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be swapped in, stop.
            if not self.block_manager.can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
            if (self._num_running_seqs() + num_new_seqs >
                    self.scheduler_config.max_num_seqs):
                break

            seq_group = self.swapped.pop(0)
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slot(seq_group, blocks_to_copy)
            self.running.append(seq_group)

        # Each RUNNING sequence counts as one batched token; admitted prompts
        # add their full prompt length below.
        num_batched_tokens = self._num_running_seqs()

        # Join waiting sequences if possible.
        prompt_group_ids: List[str] = []
        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
        # prioritized over the sequence groups in the WAITING state.
        # This is because we want to bound the amount of CPU memory taken by
        # the swapped sequence groups.
        if not self.swapped:
            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
            while self.waiting:
                seq_group = self.waiting[0]
                # If the sequence group has been preempted in this step, stop.
                if seq_group in preempted:
                    break

                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                if num_prompt_tokens >= self.scheduler_config.max_seq_len:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        " and exceeds limit of "
                        f"{self.scheduler_config.max_seq_len}")
                    for seq in seq_group.get_seqs():
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.pop(0)
                    # An ignored group consumes no blocks or batch budget, so
                    # keep considering the rest of the waiting queue instead of
                    # stalling admission for the whole step.
                    continue

                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                if (num_batched_tokens + num_prompt_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
                if (self._num_running_seqs() + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                prompt_group_ids.append(seq_group.request_id)

        scheduler_outputs = SchedulerOutputs(
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )
        if self.log_stats:
            self._log_stats(num_batched_tokens)
        return scheduler_outputs, prompt_group_ids, ignored_seq_groups

    def _log_stats(self, num_batched_tokens: int) -> None:
        """Record this step's token count and periodically log statistics.

        Emits at most one ``logger.info`` line per ``_LOGGING_INTERVAL_SEC``
        seconds. The line reports the average throughput over the last
        interval, the queue sizes and the GPU/CPU block usage.
        """
        # TODO(woosuk): Move this code to the engine.
        now = time.time()
        if num_batched_tokens > 0:
            self.num_input_tokens.append((now, num_batched_tokens))
        elapsed_time = now - self.last_logging_time
        if elapsed_time <= _LOGGING_INTERVAL_SEC:
            return

        self.last_logging_time = now
        # Keep only entries that fall inside the logging window.
        self.num_input_tokens = [
            (t, n) for t, n in self.num_input_tokens
            if now - t < _LOGGING_INTERVAL_SEC
        ]
        if len(self.num_input_tokens) > 1:
            # The newest entry's tokens are left out of the sum, while the
            # window is measured from the oldest entry's timestamp to now.
            total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
            window = now - self.num_input_tokens[0][0]
            # Guard against a zero-length window on coarse clocks.
            avg_throughput = total_num_tokens / window if window > 0 else 0.0
        else:
            avg_throughput = 0.0

        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
        num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks

        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
        if total_num_cpu_blocks > 0:
            num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
        else:
            cpu_cache_usage = 0.0

        logger.info(
            f"Throughput: {avg_throughput:.1f} tokens/s, "
            f"Running: {len(self.running)} reqs, "
            f"Swapped: {len(self.swapped)} reqs, "
            f"Pending: {len(self.waiting)} reqs, "
            f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
            f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")

    def schedule(
        self
    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
               List[SequenceGroup]]:
        """Run one scheduling step and build the model inputs.

        Returns:
            A tuple of:
            * one ``SequenceGroupMetadata`` per group in ``self.running``;
            * the block operations to perform before model execution;
            * the groups ignored because their prompts were too long.
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs, prompt_group_ids, ignored_seq_groups = (
            self._schedule())
        # A set gives O(1) membership tests in the loop below.
        prompt_ids = set(prompt_group_ids)

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in self.running:
            is_prompt = seq_group.request_id in prompt_ids

            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=is_prompt,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups

    def update(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> List[SequenceGroup]:
        """Apply one step's model outputs to the running sequences.

        Args:
            seq_outputs: Output for each RUNNING sequence, keyed by
                ``seq.seq_id``.

        Returns:
            A shallow copy of ``self.running``.
        """
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            # Process beam search results before processing the new tokens.
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam
                    # search). Free the current sequence.
                    # NOTE(review): forks are applied in place while iterating;
                    # if a parent was itself overwritten by an earlier fork in
                    # this loop, its updated state is what gets copied. Confirm
                    # this matches the beam-search sampler's expectations.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the new tokens.
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append_token_id(output.output_token, output.logprobs)
        # Return a shallow copy of the running queue to prevent the queue
        # from being modified by the caller.
        return self.running.copy()

    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
        """Mark ``seq`` with ``finish_status`` and release its blocks."""
        seq.status = finish_status
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        """Drop finished groups from the running queue.

        This method does not release any blocks itself.
        """
        self.running = [
            seq_group for seq_group in self.running
            if not seq_group.is_finished()
        ]

    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for ``seq_group`` and mark all its seqs RUNNING."""
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.RUNNING

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve a slot for the next token of each RUNNING sequence.

        When the block manager returns a ``(src_block, dst_block)`` pair, the
        copy is recorded in ``blocks_to_copy`` (``src -> [dst, ...]``).
        """
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                src_block, dst_block = ret
                blocks_to_copy.setdefault(src_block, []).append(dst_block)

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        """Preempt ``seq_group`` by recomputation or by swapping.

        Raises:
            AssertionError: if ``preemption_mode`` is not a known mode.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not supported. In such a case,
        # we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            # Explicit raise: a bare `assert False` is stripped under -O.
            raise AssertionError('Invalid preemption mode.')

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Free the group's blocks and put it back at the front of waiting."""
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.insert(0, seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group's blocks out and move it to the swapped queue."""
        # _swap_out flips RUNNING sequences to SWAPPED once the swap-out
        # mapping is obtained, mirroring _swap_in. If the swap fails, the
        # sequences are left in their original RUNNING state.
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Record the group's swap-in mapping and mark SWAPPED seqs RUNNING."""
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Record the group's swap-out mapping and mark RUNNING seqs SWAPPED.

        Raises:
            RuntimeError: if the block manager reports it cannot swap the
                group out (insufficient CPU swap space).
        """
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED