scheduler.py 25.3 KB
Newer Older
1
2
import enum
import time
3
from collections import deque
4
from dataclasses import dataclass
5
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
6

7
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
8
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
11
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
13
                           SequenceGroupMetadata, SequenceStatus)
Woosuk Kwon's avatar
Woosuk Kwon committed
14

Woosuk Kwon's avatar
Woosuk Kwon committed
15
# Module-level logger; the scheduler uses it to warn about prompts it ignores.
logger = init_logger(__name__)
16

Woosuk Kwon's avatar
Woosuk Kwon committed
17

18
19
20
21
22
23
24
25
26
27
28
29
30
class PreemptionMode(enum.Enum):
    """Strategy used to free the memory blocks of a preempted sequence group.

    SWAP: move the group's blocks out to CPU memory, and move them back in
        once the group is resumed.
    RECOMPUTE: throw the group's blocks away, and rebuild them on resumption
        by running the sequences again as if they were new prompts.
    """
    # Explicit values, identical to what ``enum.auto()`` assigns (from 1).
    SWAP = 1
    RECOMPUTE = 2


31
32
33
34
35
36
37
38
39
40
41
42
43
# A scheduled unit of work: a sequence group together with how many of its
# tokens the next model step will process.
@dataclass
class ScheduledSequenceGroup:
    # The sequence group picked for the next step.
    seq_group: SequenceGroup
    # Tokens to process for this group in the next iteration: 1 while
    # decoding; the prompt length for a prefill, or less than that when the
    # prefill is split into chunks.
    token_chunk_size: int


44
45
46
47
class SchedulerOutputs:
    """One scheduling step's batch plus the block operations it requires."""

    def __init__(
        self,
        scheduled_seq_groups: Iterable[ScheduledSequenceGroup],
        prompt_run: bool,
        num_batched_tokens: int,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        ignored_seq_groups: List[SequenceGroup],
    ) -> None:
        """A list of sequence groups to be scheduled as a single batch.

        Args:
            scheduled_seq_groups: The scheduled sequence groups, each paired
                with its token chunk size. Any iterable is accepted; it is
                copied into a list.
            prompt_run: True if all sequence groups are in prefill phase.
                If False, all sequence groups are in decoding phase.
            num_batched_tokens: Total number of batched tokens.
            blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
                number.
            blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
                number.
            blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
            ignored_seq_groups: Sequence groups that are going to be ignored.

        Raises:
            AssertionError: If both swap-in and swap-out blocks are given.
        """
        # Materialize the input. Below it is iterated twice (lora_requests,
        # then _sort_by_lora_ids) and is_empty() tests its truthiness. None
        # of that works on a one-shot iterator such as a generator.
        self.scheduled_seq_groups: List[ScheduledSequenceGroup] = list(
            scheduled_seq_groups)
        # True if all sequence groups are in prefill phase. If False, all
        # sequence groups are in decoding phase.
        self.prompt_run: bool = prompt_run
        # Total number of batched tokens.
        self.num_batched_tokens: int = num_batched_tokens
        # Blocks to swap in. Dict of CPU -> GPU block number.
        self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in
        # Blocks to swap out. Dict of GPU -> CPU block number.
        self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out
        # Blocks to copy. Source to a list of dest blocks.
        self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy
        # Sequence groups that are going to be ignored.
        self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups

        # Swap in and swap out should never happen at the same time.
        assert not (blocks_to_swap_in and blocks_to_swap_out)

        # Number of distinct lora_request values in the batch.
        # NOTE(review): groups without a LoRA presumably carry None here, and
        # None is counted as one distinct value. Confirm before relying on
        # num_loras as "number of real adapters".
        self.num_loras: int = len(self.lora_requests)
        if self.num_loras > 0:
            self._sort_by_lora_ids()

    def is_empty(self) -> bool:
        """Return True if there is nothing to run, swap or copy."""
        # NOTE: We do not consider the ignored sequence groups.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

    def _sort_by_lora_ids(self) -> None:
        """Order the batch by (lora_int_id, request_id), in place."""
        # Groups that share a LoRA ID end up adjacent; the request ID breaks
        # ties deterministically.
        self.scheduled_seq_groups = sorted(
            self.scheduled_seq_groups,
            key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """The set of distinct lora_request values in the scheduled batch."""
        return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
107

108

Woosuk Kwon's avatar
Woosuk Kwon committed
109
110
class Scheduler:
    """Chooses which sequence groups run in each engine step.

    Sequence groups live in three deques: ``waiting``, ``running`` and
    ``swapped``. Each call to :meth:`schedule` builds one of two batches:
    a prompt run from the waiting queue, or a decode run from the running
    (and possibly swapped-in) groups. Groups are preempted when the block
    manager cannot reserve a slot for them.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        self.prompt_limit = min(self.scheduler_config.max_model_len,
                                self.scheduler_config.max_num_batched_tokens)

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name="fcfs")

        BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
            version="v2" if self.scheduler_config.
            use_v2_block_manager else "v1")

        # Create the block space manager.
        self.block_manager = BlockSpaceManagerImpl(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)

        # Sequence groups in the WAITING state.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        self.swapped: Deque[SequenceGroup] = deque()

        # Time at previous scheduling step
        self.prev_time = 0.0
        # Did we schedule a prompt at previous step?
        self.prev_prompt = False
        # Latency of the last prompt step
        self.last_prompt_latency = 0.0

    @property
    def lora_enabled(self) -> bool:
        """True when a LoRA config was supplied."""
        return bool(self.lora_config)

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        """Append a new sequence group to the back of the waiting queue."""
        self.waiting.append(seq_group)

    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort the sequence group(s) with the given request ID(s).

        Each state queue (waiting, running, swapped) is searched, and every
        matching group is removed from its queue. Each of that group's
        sequences that is not already finished is marked
        ``FINISHED_ABORTED`` and freed. IDs that match no group are ignored.

        Args:
            request_id: A single request ID or an iterable of request IDs.
        """
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
        for state_queue in [self.waiting, self.running, self.swapped]:
            aborted_groups: List[SequenceGroup] = []
            for seq_group in state_queue:
                if not request_ids:
                    # Using 'break' here may add two extra iterations,
                    # but is acceptable to reduce complexity.
                    break
                if seq_group.request_id in request_ids:
                    # Appending aborted group into pending list.
                    aborted_groups.append(seq_group)
                    request_ids.remove(seq_group.request_id)
            # Remove after the scan: a deque must not be mutated while it is
            # being iterated.
            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
                state_queue.remove(aborted_group)
                for seq in aborted_group.get_seqs():
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)

    def has_unfinished_seqs(self) -> bool:
        """Return True if any queue still holds a sequence group."""
        # bool() so callers get an actual bool, not one of the deques.
        return bool(self.waiting or self.running or self.swapped)

    def get_num_unfinished_seq_groups(self) -> int:
        """Return the total number of groups across all three queues."""
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def _schedule(self) -> SchedulerOutputs:
        """Build the next batch and update the scheduler queues.

        Prompt run: only when nothing is swapped out. The waiting queue is
        scanned, and if at least one group is admitted or ignored, a prompt
        run is returned right away.

        Decode run: built from the running queue otherwise. Groups are
        preempted when the block manager cannot append a slot. If nothing
        was preempted in this step, swapped groups are swapped back in.

        Returns:
            The SchedulerOutputs describing the batch, together with the
            block swap/copy operations to perform before model execution.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.time()

        # Join waiting sequences if possible.
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[ScheduledSequenceGroup] = []
            # The total number of sequences on the fly, including the
            # requests in the generation phase.
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            curr_loras = set(
                seq_group.lora_int_id
                for seq_group in self.running) if self.lora_enabled else None

            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
            leftover_waiting_sequences: Deque[SequenceGroup] = deque()
            num_batched_tokens = 0
            while self._passed_delay(now) and self.waiting:
                seq_group = self.waiting[0]
                waiting_seqs = seq_group.get_seqs(
                    status=SequenceStatus.WAITING)
                assert len(waiting_seqs) == 1, (
                    "Waiting sequence group should have only one prompt "
                    "sequence.")
                # get_len includes output tokens if the request has been
                # preempted.
                num_prefill_tokens = waiting_seqs[0].get_len()
                if num_prefill_tokens > self.prompt_limit:
                    logger.warning(
                        f"Input prompt ({num_prefill_tokens} tokens) is too "
                        f"long and exceeds limit of {self.prompt_limit}")
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                # If the sequence group cannot be allocated, stop.
                can_allocate = self.block_manager.can_allocate(seq_group)
                if can_allocate == AllocStatus.LATER:
                    break
                elif can_allocate == AllocStatus.NEVER:
                    logger.warning(
                        f"Input prompt ({num_prefill_tokens} tokens) is too "
                        "long and exceeds the capacity of block_manager")
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                lora_int_id = 0
                if self.lora_enabled:
                    lora_int_id = seq_group.lora_int_id
                    if (lora_int_id > 0 and lora_int_id not in curr_loras
                            and len(curr_loras) >= self.lora_config.max_loras):
                        # We don't have a space for another LoRA, so
                        # we ignore this request for now.
                        leftover_waiting_sequences.appendleft(seq_group)
                        self.waiting.popleft()
                        continue

                # If the number of batched tokens exceeds the limit, stop.
                # The running total is committed only once the group is
                # actually scheduled, so a group left in the waiting queue
                # is never counted in the reported num_batched_tokens.
                new_num_batched_tokens = num_batched_tokens + num_prefill_tokens
                if (new_num_batched_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                if lora_int_id > 0:
                    curr_loras.add(lora_int_id)
                self.waiting.popleft()
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_curr_seqs += num_new_seqs
                num_batched_tokens = new_num_batched_tokens
                scheduled.append(
                    ScheduledSequenceGroup(
                        seq_group=seq_group,
                        token_chunk_size=num_prefill_tokens))

            # Put deferred LoRA groups back at the front. They were collected
            # with appendleft, and extendleft reverses again, so their
            # original relative order is preserved.
            self.waiting.extendleft(leftover_waiting_sequences)

            if scheduled or ignored_seq_groups:
                self.prev_prompt = True
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=num_batched_tokens,
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
                    ignored_seq_groups=ignored_seq_groups,
                )
                return scheduler_outputs

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: Deque[SequenceGroup] = deque()
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.popleft()
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop()
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        if not preempted:
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            curr_loras = set(
                seq_group.lora_int_id
                for seq_group in self.running) if self.lora_enabled else None

            leftover_swapped: Deque[SequenceGroup] = deque()

            while self.swapped:
                seq_group = self.swapped[0]
                lora_int_id = 0
                if self.lora_enabled:
                    lora_int_id = seq_group.lora_int_id
                    if (lora_int_id > 0 and lora_int_id not in curr_loras
                            and len(curr_loras) >= self.lora_config.max_loras):
                        # We don't have a space for another LoRA, so
                        # we ignore this request for now.
                        leftover_swapped.appendleft(seq_group)
                        self.swapped.popleft()
                        continue

                # If the sequence group cannot be swapped in, stop.
                if not self.block_manager.can_swap_in(seq_group):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                if lora_int_id > 0:
                    curr_loras.add(lora_int_id)
                self.swapped.popleft()
                self._swap_in(seq_group, blocks_to_swap_in)
                self._append_slot(seq_group, blocks_to_copy)
                num_curr_seqs += num_new_seqs
                self.running.append(seq_group)

            # Same order-preserving appendleft/extendleft pairing as above.
            self.swapped.extendleft(leftover_swapped)

        # Each sequence in the generation phase only takes one token slot.
        # Therefore, the number of batched tokens is equal to the number of
        # sequences in the RUNNING state.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)

        scheduler_outputs = SchedulerOutputs(
            scheduled_seq_groups=[
                ScheduledSequenceGroup(seq_group=running_group,
                                       token_chunk_size=1)
                for running_group in self.running
            ],
            prompt_run=False,
            num_batched_tokens=num_batched_tokens,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
        )
        return scheduler_outputs

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Run one scheduling step and build per-group metadata.

        Returns:
            A tuple of the metadata list (one entry per scheduled sequence
            group, in batch order) and the SchedulerOutputs of this step.
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()
        now = time.time()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            token_chunk_size = scheduled_seq_group.token_chunk_size
            seq_group.maybe_set_first_scheduled_time(now)

            # seq_id -> SequenceData
            seq_data: Dict[int, SequenceData] = {}
            # seq_id -> physical block numbers
            block_tables: Dict[int, List[int]] = {}

            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                self.block_manager.access_all_blocks_in_seq(seq, now)

            common_computed_block_nums = (
                self.block_manager.get_common_computed_block_ids(
                    seq_group.get_seqs(status=SequenceStatus.RUNNING)))

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                token_chunk_size=token_chunk_size,
                lora_request=seq_group.lora_request,
                computed_block_nums=common_computed_block_nums,
                state=seq_group.state,
                # `multi_modal_data` will only be present for the 1st comm
                # between engine and worker.
                # the subsequent comms can still use delta, but
                # `multi_modal_data` will be None.
                multi_modal_data=seq_group.multi_modal_data
                if scheduler_outputs.prompt_run else None,
            )
            seq_group_metadata_list.append(seq_group_metadata)

        # Now that the batch has been created, we can assume all blocks in the
        # batch will have been computed before the next scheduling invocation.
        # This is because the engine assumes that a failure in model execution
        # will crash the vLLM instance / will not retry.
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            self.block_manager.mark_blocks_as_computed(
                scheduled_seq_group.seq_group)

        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Register child_seq as a fork of parent_seq in the block manager."""
        self.block_manager.fork(parent_seq, child_seq)

    def free_seq(self, seq: Sequence) -> None:
        """Free a sequence from a block table."""
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        """Drop finished sequence groups from the running queue."""
        self.running = deque(seq_group for seq_group in self.running
                             if not seq_group.is_finished())

    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for seq_group and mark its waiting seqs RUNNING."""
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            seq.status = SequenceStatus.RUNNING

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve a new token slot for each running sequence of seq_group.

        Whenever the block manager returns a ``(src, dst)`` block pair, dst
        is appended to ``blocks_to_copy[src]``; the dict is updated in place.
        """
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                src_block, dst_block = ret
                blocks_to_copy.setdefault(src_block, []).append(dst_block)

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        """Preempt seq_group by recomputation or swapping.

        Raises:
            AssertionError: If preemption_mode is not a known mode.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Free the group's blocks and requeue it at the front of waiting."""
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.free_seq(seq)
            seq.reset_state_for_recompute()
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.appendleft(seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group out and append it to the swapped queue."""
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Swap the group in, record the block mapping, mark seqs RUNNING."""
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group out, record the block mapping, mark seqs SWAPPED.

        Raises:
            RuntimeError: If the block manager reports the group cannot be
                swapped out.
        """
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED

    def _passed_delay(self, now: float) -> bool:
        """Return whether prompts may be scheduled at time ``now``.

        Side effects:
            If the previous call was followed by a prompt run
            (``prev_prompt``), the time elapsed since that call is stored in
            ``last_prompt_latency``. ``prev_time`` and ``prev_prompt`` are
            then reset for the next call.

        The delay applies only when ``delay_factor`` is positive and
        requests are waiting. Prompts may then be scheduled in either case:
            * the oldest waiting request has waited longer than
              ``delay_factor * last_prompt_latency``;
            * nothing is running.
        """
        if self.prev_prompt:
            self.last_prompt_latency = now - self.prev_time
        self.prev_time, self.prev_prompt = now, False
        # Delay scheduling prompts to let waiting queue fill up
        if self.scheduler_config.delay_factor > 0 and self.waiting:
            earliest_arrival_time = min(
                e.metrics.arrival_time for e in self.waiting)
            passed_delay = (
                (now - earliest_arrival_time) >
                (self.scheduler_config.delay_factor * self.last_prompt_latency)
                or not self.running)
        else:
            passed_delay = True
        return passed_delay