scheduler.py 51.3 KB
Newer Older
1
import enum
2
3
import os
import random
4
import time
5
from collections import deque
6
from dataclasses import dataclass, field
7
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
10
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
15
                           SequenceGroupMetadata, SequenceStatus)
Woosuk Kwon's avatar
Woosuk Kwon committed
16

Woosuk Kwon's avatar
Woosuk Kwon committed
17
# Module-level logger for this file, obtained from vLLM's init_logger helper.
logger = init_logger(__name__)
18

19
20
21
22
23
24
25
def _env_flag_enabled(raw: Optional[str]) -> bool:
    """Interpret an environment-variable string as a boolean flag.

    ``None`` (unset), the empty string and the usual "off" spellings
    ("0", "false", "no", "off"; case-insensitive, surrounding whitespace
    ignored) are False. Any other value is True.

    A plain ``bool(os.getenv(...))`` would treat "0" and "false" as enabled,
    because every non-empty string is truthy.
    """
    if raw is None:
        return False
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


# Test-only. If configured, decode is preempted with probability
# ARTIFICIAL_PREEMPTION_PROB (presumably compared against a uniform draw in
# [0, 1) at the use site -- TODO confirm).
ENABLE_ARTIFICIAL_PREEMPT = _env_flag_enabled(
    os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT"))  # noqa
ARTIFICIAL_PREEMPTION_PROB = 0.5
# Initial value of Scheduler.artificial_preempt_cnt when the flag is enabled.
ARTIFICIAL_PREEMPTION_MAX_CNT = 500
class PreemptionMode(enum.Enum):
    """How a preempted sequence group gets its KV-cache state back.

    SWAP: the group's blocks are moved out to CPU memory and moved back in
        when the group resumes.
    RECOMPUTE: the group's blocks are dropped; on resumption the sequences
        are treated as fresh prompts and recomputed.
    """
    # Explicit values match what enum.auto() assigns (1, 2).
    SWAP = 1
    RECOMPUTE = 2
@dataclass
class SchedulingBudget:
    """Token and sequence capacity left in the current scheduling step.

    Updates are de-duplicated per request id: adding for an id that is
    already counted is a no-op, and subtracting for an id that is not
    counted is a no-op.

    TODO(sang): The request-id awareness exists because the normal
    scheduling path adds RUNNING num_seqs ahead of time, so the same request
    may be added more than once. Chunked-prefill-only scheduling does not
    need it, so it can be dropped once chunked prefill is the default.
    """
    token_budget: int
    max_num_seqs: int
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
        """Whether the extra tokens and sequences fit within both limits."""
        assert num_new_tokens != 0
        assert num_new_seqs != 0
        tokens_fit = (self.num_batched_tokens + num_new_tokens
                      <= self.token_budget)
        return tokens_fit and (self.num_curr_seqs + num_new_seqs
                               <= self.max_num_seqs)

    def remaining_token_budget(self):
        """Tokens still available before hitting ``token_budget``."""
        return self.token_budget - self.num_batched_tokens

    def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
        """Count ``num_batched_tokens`` for ``req_id`` unless already counted."""
        counted = self._request_ids_num_batched_tokens
        if req_id not in counted:
            counted.add(req_id)
            self._num_batched_tokens += num_batched_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int):
        """Release ``req_id``'s tokens if they are currently counted."""
        counted = self._request_ids_num_batched_tokens
        if req_id not in counted:
            return
        counted.discard(req_id)
        self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Count ``num_curr_seqs`` for ``req_id`` unless already counted."""
        counted = self._request_ids_num_curr_seqs
        if req_id not in counted:
            counted.add(req_id)
            self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Release ``req_id``'s sequences if they are currently counted."""
        counted = self._request_ids_num_curr_seqs
        if req_id not in counted:
            return
        counted.discard(req_id)
        self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self):
        """Tokens counted so far in this step."""
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self):
        """Sequences counted so far in this step."""
        return self._num_curr_seqs
@dataclass
class ScheduledSequenceGroup:
    """A sequence group selected for the next iteration.

    Attributes:
        seq_group: The scheduled sequence group.
        token_chunk_size: Number of tokens to process for the group in the
            next iteration. It is 1 for decoding. For prefill it is the
            prompt length, or fewer tokens when the prefill is chunked.
    """
    seq_group: SequenceGroup
    token_chunk_size: int
@dataclass
class SchedulerOutputs:
    """The scheduling decision made from a scheduler."""
    # Scheduled sequence groups.
    scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
    # Number of prefill groups scheduled.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]]
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]]
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # The number of requests in the running queue.
    running_queue_size: int
    # Presumably the number of preempted sequence groups in this step;
    # TODO confirm at the construction site.
    preempted: int

    def __post_init__(self):
        """Validate the swap plan and derive LoRA / prompt-adapter counts.

        Also sorts ``scheduled_seq_groups`` by (lora_int_id, request_id)
        when any LoRA request is present.
        """
        # Swap in and swap out should never happen at the same time.
        assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)

        # `scheduled_seq_groups` is typed as Iterable but is traversed more
        # than once (lora_requests, prompt_adapter_requests, sorting,
        # is_empty). A one-shot iterator such as a generator would be
        # exhausted by the first pass, so materialize it. An iterator's
        # __iter__ returns itself; re-iterable containers (list, tuple,
        # deque) are left untouched.
        if iter(self.scheduled_seq_groups) is self.scheduled_seq_groups:
            self.scheduled_seq_groups = list(self.scheduled_seq_groups)

        self.num_loras: int = len(self.lora_requests)
        if self.num_loras > 0:
            self._sort_by_lora_ids()

        self.num_prompt_adapters: int = len(self.prompt_adapter_requests)

    def is_empty(self) -> bool:
        """True when nothing is scheduled and no block operations are needed."""
        # NOTE: We do not consider the ignored sequence groups.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

    def _sort_by_lora_ids(self):
        """Order scheduled groups by (lora_int_id, request_id)."""
        self.scheduled_seq_groups = sorted(
            self.scheduled_seq_groups,
            key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """Distinct non-None LoRA requests of the scheduled groups."""
        return {
            g.seq_group.lora_request
            for g in self.scheduled_seq_groups
            if g.seq_group.lora_request is not None
        }

    @property
    def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
        """Distinct non-None prompt-adapter requests of the scheduled groups."""
        return {
            g.seq_group.prompt_adapter_request
            for g in self.scheduled_seq_groups
            if g.seq_group.prompt_adapter_request is not None
        }
@dataclass
class SchedulerRunningOutputs:
    """The requests that are scheduled from a running queue.

    Could contain prefill (prefill that's chunked) or decodes. If there's not
    enough memory, it can be preempted (for recompute) or swapped out.
    """
    # Selected sequences that are running and in a decoding phase.
    # _schedule_running stores ScheduledSequenceGroup objects here (with
    # token_chunk_size=1), so the annotation reflects that.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are running and in a prefill phase.
    # I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The sequence groups preempted for recomputation.
    preempted: List[SequenceGroup]
    # Sequence groups that are swapped out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an instance with empty lists and zero lookahead slots."""
        # `cls` (instead of the hard-coded class name) keeps subclasses
        # returning their own type.
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
        )
@dataclass
class SchedulerSwappedInOutputs:
    """The requests that are scheduled from a swap queue.

    Could contain prefill (prefill that's chunked) or decodes.
    """
    # Selected sequences that are going to be swapped in and are in a
    # decoding phase. _schedule_swapped stores ScheduledSequenceGroup objects
    # here, so the annotation reflects that.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are going to be swapped in and in a prefill
    # phase. I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # Infeasible sequence groups (can never be swapped in).
    infeasible_seq_groups: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance with empty lists and zero lookahead slots."""
        # `cls` (instead of the hard-coded class name) keeps subclasses
        # returning their own type.
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            blocks_to_swap_in=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            infeasible_seq_groups=[],
        )
@dataclass
class SchedulerPrefillOutputs:
    """Requests scheduled out of the waiting queue.

    These are either fresh prefill requests or previously preempted requests
    that need to be recomputed from scratch.

    Attributes:
        seq_groups: Sequence groups selected for prefill.
        ignored_seq_groups: Sequence groups that are ignored rather than
            scheduled (e.g. prompts exceeding the prompt limit).
        num_lookahead_slots: Number of slots for lookahead decoding.
    """
    seq_groups: List[SequenceGroup]
    ignored_seq_groups: List[SequenceGroup]
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Build an instance with empty lists and zero lookahead slots."""
        no_lookahead = 0
        return SchedulerPrefillOutputs(
            num_lookahead_slots=no_lookahead,
            ignored_seq_groups=[],
            seq_groups=[],
        )
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
264
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        pipeline_parallel_size: int = 1,
    ) -> None:
        """Build the block manager and empty scheduling queues.

        Args:
            scheduler_config: Scheduler settings; its block-manager flags
                and preemption mode are read here.
            cache_config: KV-cache settings (block counts, block size,
                sliding window, prefix caching) used to build the block
                manager.
            lora_config: LoRA settings, or None.
            pipeline_parallel_size: The configured GPU and CPU block counts
                are integer-divided by this value.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        # Embedding mode takes precedence over the v2 flag; v1 is the default.
        if self.scheduler_config.embedding_mode:
            version = "embedding"
        elif self.scheduler_config.use_v2_block_manager:
            version = "v2"
        else:
            version = "v1"
        BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
            version)

        def _divide_by_stages(num_blocks):
            # Falsy counts (e.g. None or 0) are passed through untouched.
            if not num_blocks:
                return num_blocks
            return num_blocks // pipeline_parallel_size

        gpu_blocks = _divide_by_stages(cache_config.num_gpu_blocks)
        cpu_blocks = _divide_by_stages(cache_config.num_cpu_blocks)

        # Create the block space manager.
        self.block_manager = BlockSpaceManagerImpl(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=gpu_blocks,
            num_cpu_blocks=cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)

        # WAITING state: new prefill requests or preempted requests.
        self.waiting: Deque[SequenceGroup] = deque()
        # RUNNING state: decode requests.
        self.running: Deque[SequenceGroup] = deque()
        # SWAPPED state: decode requests that were swapped out.
        self.swapped: Deque[SequenceGroup] = deque()
        # Request ids of sequence groups finished since the last step. Any
        # state tied to these requests can and must be released after the
        # current step; it is used to evict finished requests from the
        # Mamba cache.
        self._finished_requests_ids: List[str] = []

        # Time at the previous scheduling step.
        self.prev_time = 0.0
        # Whether a prompt was scheduled at the previous step.
        self.prev_prompt = False
        # Latency of the last prompt step.
        self.last_prompt_latency = 0.0

        # Preemption mode requested by the user (RECOMPUTE or SWAP).
        self.user_specified_preemption_mode = scheduler_config.preemption_mode

        # Test-only fields used to inject artificial preemption.
        self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
        if self.enable_artificial_preemption:
            self.artificial_preempt_cnt = ARTIFICIAL_PREEMPTION_MAX_CNT
        else:
            self.artificial_preempt_cnt = 0
        self.num_cumulative_preemption: int = 0
    @property
    def lora_enabled(self) -> bool:
        """Whether a (truthy) LoRA config was supplied to the scheduler."""
        return True if self.lora_config else False
    @property
    def num_decoding_tokens_per_seq(self) -> int:
        """Number of new tokens per sequence in a decode step (always 1)."""
        tokens_per_decode_step = 1
        return tokens_per_decode_step
    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        """Append a new sequence group to the back of the waiting queue."""
        waiting_queue = self.waiting
        waiting_queue.append(seq_group)
    def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
        """Append a sequence group to the running queue (testing only)."""
        running_queue = self.running
        running_queue.append(seq_group)
    def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
        """Append a sequence group to the swapped queue (testing only)."""
        swapped_queue = self.swapped
        swapped_queue.append(seq_group)
    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort the sequence group(s) with the given request id(s).

        The waiting, running and swapped queues are searched in that order.
        For each id, the first matching group found is removed from its
        queue and its request id is recorded as finished. Every sequence in
        that group that is not already finished is marked
        ``FINISHED_ABORTED`` and freed. Ids that match nothing are ignored.

        Args:
            request_id: A single request id or an iterable of request ids.
        """
        pending_ids = ({request_id}
                       if isinstance(request_id, str) else set(request_id))
        for queue in (self.waiting, self.running, self.swapped):
            matched: List[SequenceGroup] = []
            for group in queue:
                if not pending_ids:
                    # Every id has been found; stop scanning (an empty set
                    # still costs one iteration per remaining queue).
                    break
                if group.request_id not in pending_ids:
                    continue
                matched.append(group)
                pending_ids.discard(group.request_id)

            # Mutate the queue only after iterating over it.
            for group in matched:
                queue.remove(group)
                # Lets the Mamba cache release this request's state.
                self._finished_requests_ids.append(group.request_id)
                for seq in group.get_seqs():
                    if not seq.is_finished():
                        seq.status = SequenceStatus.FINISHED_ABORTED
                        self.free_seq(seq)
    def has_unfinished_seqs(self) -> bool:
        """True if any of the waiting, running or swapped queues is non-empty."""
        return any(
            len(queue) != 0
            for queue in (self.waiting, self.running, self.swapped))
    def get_num_unfinished_seq_groups(self) -> int:
        """Total number of groups across the waiting, running and swapped queues."""
        return sum(
            len(queue) for queue in (self.waiting, self.running, self.swapped))
    def get_and_reset_finished_requests_ids(self) -> List[str]:
        """Return the ids of groups finished since the last call, then clear them."""
        flushed, self._finished_requests_ids = self._finished_requests_ids, []
        return flushed
    def _schedule_running(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerRunningOutputs:
        """Schedule the sequence groups in the RUNNING queue.

        The running queue holds decode and chunked-prefill requests. Groups
        are taken from the front of ``self.running``. When a group cannot get
        the slots it needs, groups are preempted from the back of the queue
        (lowest priority first) until it fits. If nothing else is left to
        preempt, the group itself is preempted.

        Args:
            budget: The scheduling budget. Updated in place when groups are
                scheduled or preempted.
            curr_loras: Currently batched LoRA int ids, or None. Updated in
                place when groups are scheduled or preempted.
            enable_chunking: If True, a group may be chunked so that only the
                number of tokens that fits in ``budget.num_batched_tokens``
                is scheduled.

        Returns:
            SchedulerRunningOutputs.
        """
        # Block operations that must happen before model execution.
        swap_out_blocks: List[Tuple[int, int]] = []
        copy_blocks: List[Tuple[int, int]] = []

        decodes: List[ScheduledSequenceGroup] = []
        prefills: List[ScheduledSequenceGroup] = []
        recompute_victims: List[SequenceGroup] = []
        swap_victims: List[SequenceGroup] = []

        def _evict(victim: SequenceGroup) -> None:
            # Preempt `victim` and record it under the mode _preempt chose.
            mode = self._preempt(victim, swap_out_blocks)
            if mode == PreemptionMode.RECOMPUTE:
                recompute_victims.append(victim)
            else:
                swap_victims.append(victim)

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        queue = self.running
        while queue:
            group = queue[0]
            num_tokens = self._get_num_new_tokens(group,
                                                  SequenceStatus.RUNNING,
                                                  enable_chunking, budget)
            if num_tokens == 0:
                break
            queue.popleft()

            while not self._can_append_slots(group):
                # NOTE(review): the budget and curr_loras are released for
                # `group` even when a different victim is preempted below;
                # confirm this accounting is intended.
                budget.subtract_num_batched_tokens(group.request_id,
                                                   num_tokens)
                budget.subtract_num_seqs(group.request_id,
                                         group.get_max_num_running_seqs())
                if (curr_loras is not None and group.lora_int_id > 0
                        and group.lora_int_id in curr_loras):
                    curr_loras.remove(group.lora_int_id)

                if not queue:
                    # Nothing lower-priority remains: preempt this group
                    # itself and stop trying to schedule it.
                    _evict(group)
                    break
                # Preempt the lowest-priority group (back of the queue).
                _evict(queue.pop())
            else:
                # Reached only when the inner loop ends without `break`,
                # i.e. `group` now has the slots it needs.
                self._append_slots(group, copy_blocks)
                if group.is_prefill():
                    prefills.append(
                        ScheduledSequenceGroup(seq_group=group,
                                               token_chunk_size=num_tokens))
                else:
                    decodes.append(
                        ScheduledSequenceGroup(seq_group=group,
                                               token_chunk_size=1))
                budget.add_num_batched_tokens(group.request_id, num_tokens)
                # OPTIMIZATION: get_max_num_running_seqs is expensive. In the
                # default (non-chunking) path num_seqs are added to the budget
                # before this method runs, so only update it when chunking.
                if enable_chunking:
                    budget.add_num_seqs(group.request_id,
                                        group.get_max_num_running_seqs())
                if curr_loras is not None and group.lora_int_id > 0:
                    curr_loras.add(group.lora_int_id)

        return SchedulerRunningOutputs(
            decode_seq_groups=decodes,
            prefill_seq_groups=prefills,
            preempted=recompute_victims,
            swapped_out=swap_victims,
            blocks_to_swap_out=swap_out_blocks,
            blocks_to_copy=copy_blocks,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False))
    def _schedule_swapped(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerSwappedInOutputs:
        """Schedule sequence groups that are currently swapped out.

        Groups are examined from the front of ``self.swapped`` and swapped
        back in while they fit in ``budget`` and the number of distinct LoRAs
        stays within ``lora_config.max_loras``. Groups skipped only because
        of the LoRA limit are returned to the front of the queue in their
        original order. Groups the block manager can never swap in are
        marked ``FINISHED_IGNORED`` and reported as infeasible.

        Args:
            budget: The scheduling budget. Updated in place for every group
                that is swapped in.
            curr_loras: Currently batched LoRA int ids, or None. Updated in
                place for every group that is swapped in.
            enable_chunking: If True, a group may be chunked so that only the
                number of tokens that fits in ``budget.num_batched_tokens``
                is scheduled.

        Returns:
            SchedulerSwappedInOutputs.
        """
        # Block operations that must happen before model execution.
        swap_in_blocks: List[Tuple[int, int]] = []
        copy_blocks: List[Tuple[int, int]] = []
        decodes: List[ScheduledSequenceGroup] = []
        prefills: List[ScheduledSequenceGroup] = []
        infeasible: List[SequenceGroup] = []
        # Groups deferred because no LoRA slot was free, in queue order.
        lora_deferred: List[SequenceGroup] = []

        queue = self.swapped
        while queue:
            group = queue[0]

            # Stop at the first group that cannot be swapped in yet; fail
            # groups that can never be swapped in.
            status = self.block_manager.can_swap_in(
                group, self._get_num_lookahead_slots(group.is_prefill()))
            if status == AllocStatus.LATER:
                break
            if status == AllocStatus.NEVER:
                logger.warning(
                    "Failing the request %s because there's not enough kv "
                    "cache blocks to run the entire sequence.",
                    group.request_id)
                for seq in group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible.append(group)
                queue.popleft()
                continue

            lora_id = 0
            if self.lora_enabled:
                lora_id = group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                needs_new_lora_slot = (lora_id > 0
                                       and lora_id not in curr_loras)
                if (needs_new_lora_slot
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # No room for another LoRA in this step; retry later.
                    lora_deferred.append(group)
                    queue.popleft()
                    continue

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = group.get_max_num_running_seqs()
            num_new_tokens = self._get_num_new_tokens(group,
                                                      SequenceStatus.SWAPPED,
                                                      enable_chunking, budget)
            # can_schedule asserts non-zero counts, so test zero tokens first.
            if num_new_tokens == 0 or not budget.can_schedule(
                    num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs):
                break

            if lora_id > 0 and curr_loras is not None:
                curr_loras.add(lora_id)
            queue.popleft()
            self._swap_in(group, swap_in_blocks)
            self._append_slots(group, copy_blocks)

            if group.is_prefill():
                prefills.append(
                    ScheduledSequenceGroup(group,
                                           token_chunk_size=num_new_tokens))
            else:
                decodes.append(
                    ScheduledSequenceGroup(group, token_chunk_size=1))
            budget.add_num_batched_tokens(group.request_id, num_new_tokens)
            budget.add_num_seqs(group.request_id, num_new_seqs)

        # extendleft prepends one item at a time, so feeding it the deferred
        # groups in reverse restores their original order at the front.
        queue.extendleft(reversed(lora_deferred))

        return SchedulerSwappedInOutputs(
            decode_seq_groups=decodes,
            prefill_seq_groups=prefills,
            blocks_to_swap_in=swap_in_blocks,
            blocks_to_copy=copy_blocks,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False),
            infeasible_seq_groups=infeasible,
        )
    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
        """Return the maximum prompt length accepted for ``seq_group``.

        With chunked prefill enabled, the limit is ``max_model_len``.
        Otherwise it is also capped by ``max_num_batched_tokens``. If the
        group's LoRA request sets ``long_lora_max_len`` (a model fine-tuned
        for long context), that value is returned instead. It must not be
        smaller than the base limit.
        """
        config = self.scheduler_config
        if config.chunked_prefill_enabled:
            limit = config.max_model_len
        else:
            limit = min(config.max_model_len, config.max_num_batched_tokens)

        lora_request = seq_group.lora_request
        if not lora_request or not lora_request.long_lora_max_len:
            return limit
        # The fine-tuned context length must cover the base limit.
        assert limit <= lora_request.long_lora_max_len
        return lora_request.long_lora_max_len
    def _schedule_prefills(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
645
        enable_chunking: bool = False,
646
    ) -> SchedulerPrefillOutputs:
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
        """Schedule sequence groups that are in prefill stage.

        Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
        as a new prefill (that starts from beginning -> most recently generated
        tokens).

        It schedules waiting requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are scheduled.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are scheduled.
662
663
664
665
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled  if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.
666
667
668
669
670
671

        Returns:
            SchedulerSwappedInOutputs.
        """
        ignored_seq_groups: List[SequenceGroup] = []
        seq_groups: List[SequenceGroup] = []
672
673

        waiting_queue = self.waiting
674

675
        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
676
677
678
679
680
681
682
        while self._passed_delay(time.time()) and waiting_queue:
            seq_group = waiting_queue[0]

            waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
683
684
685
686
687
688
689
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.WAITING,
                                                      enable_chunking, budget)
            if not enable_chunking:
                num_prompt_tokens = waiting_seqs[0].get_len()
                assert num_new_tokens == num_prompt_tokens

690
691
            prompt_limit = self._get_prompt_limit(seq_group)
            if num_new_tokens > prompt_limit:
692
                logger.warning(
693
                    "Input prompt (%d tokens) is too long"
694
                    " and exceeds limit of %d", num_new_tokens, prompt_limit)
695
696
697
698
699
700
701
702
703
704
705
706
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            # If the sequence group cannot be allocated, stop.
            can_allocate = self.block_manager.can_allocate(seq_group)
            if can_allocate == AllocStatus.LATER:
                break
            elif can_allocate == AllocStatus.NEVER:
                logger.warning(
707
708
709
                    "Input prompt (%d tokens) is too long"
                    " and exceeds the capacity of block_manager",
                    num_new_tokens)
710
711
712
713
714
715
716
717
718
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
719
720
                assert curr_loras is not None
                assert self.lora_config is not None
721
722
723
724
725
726
727
728
729
730
                if (self.lora_enabled and lora_int_id > 0
                        and lora_int_id not in curr_loras
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_waiting_sequences.appendleft(seq_group)
                    waiting_queue.popleft()
                    continue

            num_new_seqs = seq_group.get_max_num_running_seqs()
731
732
733
            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
734
735
736
737
738
739
                break

            # Can schedule this request.
            if curr_loras is not None and lora_int_id > 0:
                curr_loras.add(lora_int_id)
            waiting_queue.popleft()
740
            self._allocate_and_set_running(seq_group)
741
742
            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
743
744
745
                                       token_chunk_size=num_new_tokens))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)
746
747
748
749
750
751

        # Queue requests that couldn't be scheduled.
        waiting_queue.extendleft(leftover_waiting_sequences)
        if len(seq_groups) > 0:
            self.prev_prompt = True

752
        return SchedulerPrefillOutputs(
753
754
755
756
            seq_groups=seq_groups,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))

757
758
    def _schedule_default(self) -> SchedulerOutputs:
        """Schedule queued requests with the default (non-chunked) policy.

        The policy favors throughput. New prefills from the waiting queue are
        admitted first, but only when no sequence group is swapped out.
        Running decodes are scheduled only when no prefill was admitted in
        this step, and swapped-out groups are resumed only when no running
        group had to be preempted. Under GPU memory pressure, running decode
        requests may be swapped out or preempted.
        """
        token_budget = SchedulingBudget(
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
        # Count the already-running sequences before the prefill pass so
        # that new prefills cannot push the batch past max_num_seqs.
        for group in self.running:
            token_budget.add_num_seqs(group.request_id,
                                      group.get_max_num_running_seqs())

        active_loras: Optional[Set[int]] = None
        if self.lora_enabled:
            active_loras = {
                group.lora_int_id
                for group in self.running if group.lora_int_id > 0
            }

        prefill_out = SchedulerPrefillOutputs.create_empty()
        running_out = SchedulerRunningOutputs.create_empty()
        swap_in_out = SchedulerSwappedInOutputs.create_empty()

        # Swapped-out groups have priority: only admit new prefills when
        # nothing is waiting to be swapped back in.
        if not self.swapped:
            prefill_out = self._schedule_prefills(token_budget,
                                                  active_loras,
                                                  enable_chunking=False)

        # Decodes are not batched together with prefills under this policy.
        # Since chunking is disabled, self.running holds only decode
        # requests here (no partially computed prefills).
        if not prefill_out.seq_groups:
            running_out = self._schedule_running(token_budget,
                                                 active_loras,
                                                 enable_chunking=False)
            # Any preemption means there is no room for more running
            # groups, so only swap groups back in when nothing was evicted.
            if not running_out.preempted and not running_out.swapped_out:
                swap_in_out = self._schedule_swapped(token_budget,
                                                     active_loras)

        assert (token_budget.num_batched_tokens <=
                self.scheduler_config.max_num_batched_tokens)
        assert token_budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

        # Groups preempted by recomputation go back to the waiting queue.
        self.waiting.extendleft(running_out.preempted)
        # Newly admitted prefills, continuing decodes and swapped-in decodes
        # (re)join the running queue, in that order.
        for newly_running in (prefill_out.seq_groups,
                              running_out.decode_seq_groups,
                              swap_in_out.decode_seq_groups):
            self.running.extend([s.seq_group for s in newly_running])
        # Groups evicted by swapping join the swapped queue.
        self.swapped.extend(running_out.swapped_out)
        num_preempted = (len(running_out.preempted) +
                         len(running_out.swapped_out))

        # Chunked prefill is disabled, so neither the running pass nor the
        # swap-in pass may produce prefill groups.
        assert len(running_out.prefill_seq_groups) == 0
        assert len(swap_in_out.prefill_seq_groups) == 0

        return SchedulerOutputs(
            scheduled_seq_groups=(prefill_out.seq_groups +
                                  running_out.decode_seq_groups +
                                  swap_in_out.decode_seq_groups),
            num_prefill_groups=len(prefill_out.seq_groups),
            num_batched_tokens=token_budget.num_batched_tokens,
            blocks_to_swap_in=swap_in_out.blocks_to_swap_in,
            blocks_to_swap_out=running_out.blocks_to_swap_out,
            blocks_to_copy=running_out.blocks_to_copy +
            swap_in_out.blocks_to_copy,
            ignored_seq_groups=prefill_out.ignored_seq_groups +
            swap_in_out.infeasible_seq_groups,
            num_lookahead_slots=running_out.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=num_preempted,
        )

    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
        """Schedule queued requests with chunked prefill enabled.

        Chunked prefill allows prefill requests to be split into chunks and
        batched together with decode requests. This policy schedules, in
        order:
        1. as many running decode requests as possible,
        2. running prefill requests that are not finished yet (chunked
           prefills),
        3. swapped-out requests, unless a running request was just
           preempted, and
        4. new prefill requests from the waiting queue, chunked to fit the
           remaining token budget.

        The policy sustains high GPU utilization because prefill and decode
        requests share a batch, and it improves inter-token latency because
        decode requests do not have to wait behind prefill requests.

        Returns:
            The SchedulerOutputs for this step. Prefill groups (new, running
            and swapped-in) come before decode groups in
            ``scheduled_seq_groups``.
        """
        budget = SchedulingBudget(
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
        # NOTE(review): unlike _schedule_default, this set starts empty
        # rather than being seeded from self.running; it presumably relies on
        # _schedule_running recording the LoRAs of the groups it keeps —
        # confirm.
        curr_loras: Set[int] = set()

        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Running requests (decodes and unfinished chunked prefills) are
        # always scheduled first, in FCFS order.
        running_scheduled = self._schedule_running(budget,
                                                   curr_loras,
                                                   enable_chunking=True)

        # Schedule swapped-out requests. A preemption in the step above means
        # there is no space to swap anything back in.
        if len(running_scheduled.preempted) + len(
                running_scheduled.swapped_out) == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)

        # New prefills take whatever budget is left and may be chunked.
        prefills = self._schedule_prefills(budget,
                                           curr_loras,
                                           enable_chunking=True)

        assert (budget.num_batched_tokens <=
                self.scheduler_config.max_num_batched_tokens)
        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

        # Update waiting requests.
        self.waiting.extendleft(running_scheduled.preempted)
        # Update new running requests.
        self.running.extend([s.seq_group for s in prefills.seq_groups])
        self.running.extend(
            [s.seq_group for s in running_scheduled.decode_seq_groups])
        self.running.extend(
            [s.seq_group for s in running_scheduled.prefill_seq_groups])
        self.running.extend(
            [s.seq_group for s in swapped_in.decode_seq_groups])
        self.running.extend(
            [s.seq_group for s in swapped_in.prefill_seq_groups])
        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)
        return SchedulerOutputs(
            scheduled_seq_groups=(prefills.seq_groups +
                                  running_scheduled.prefill_seq_groups +
                                  swapped_in.prefill_seq_groups +
                                  running_scheduled.decode_seq_groups +
                                  swapped_in.decode_seq_groups),
            num_prefill_groups=(len(prefills.seq_groups) +
                                len(swapped_in.prefill_seq_groups) +
                                len(running_scheduled.prefill_seq_groups)),
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=running_scheduled.blocks_to_copy +
            swapped_in.blocks_to_copy,
            ignored_seq_groups=prefills.ignored_seq_groups +
            swapped_in.infeasible_seq_groups,
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=(len(running_scheduled.preempted) +
                       len(running_scheduled.swapped_out)),
        )

    def _schedule(self) -> SchedulerOutputs:
        """Run the scheduling policy selected by the scheduler config.

        Uses the chunked-prefill policy when
        ``scheduler_config.chunked_prefill_enabled`` is set and the default
        policy otherwise.
        """
        if not self.scheduler_config.chunked_prefill_enabled:
            return self._schedule_default()
        return self._schedule_chunked_prefill()

    def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
        """Return whether there is enough KV cache space to keep generating.

        When artificial preemption is enabled (test-only), this randomly
        reports ``False`` with probability ARTIFICIAL_PREEMPTION_PROB while
        ``self.artificial_preempt_cnt`` is positive, decrementing that counter
        each time it does so.
        """
        # The random draw happens before the counter check, exactly as the
        # short-circuit order below dictates.
        force_preempt = (self.enable_artificial_preemption
                         and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
                         and self.artificial_preempt_cnt > 0)
        if force_preempt:
            self.artificial_preempt_cnt -= 1
            return False

        # Slots are only appended while decoding, so use decode lookahead.
        decode_lookahead = self._get_num_lookahead_slots(False)
        return self.block_manager.can_append_slots(
            seq_group=seq_group,
            num_lookahead_slots=decode_lookahead,
        )

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Run one scheduling step and build per-group metadata for workers.

        Calling this changes the scheduler's internal queues
        (``self.running``, ``self.swapped`` and ``self.waiting``) via
        ``_schedule``. Afterwards, all blocks of every scheduled group are
        marked as computed.

        Returns:
            A tuple ``(seq_group_metadata_list, scheduler_outputs)``. The list
            holds one SequenceGroupMetadata per scheduled sequence group, in
            the order of ``scheduler_outputs.scheduled_seq_groups``.
        """
        scheduler_outputs = self._schedule()
        now = time.time()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            token_chunk_size = scheduled_seq_group.token_chunk_size
            seq_group.maybe_set_first_scheduled_time(now)

            running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            # seq_id -> SequenceData
            seq_data: Dict[int, SequenceData] = {}
            # seq_id -> physical block numbers
            block_tables: Dict[int, List[int]] = {}
            for seq in running_seqs:
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                self.block_manager.access_all_blocks_in_seq(seq, now)

            common_computed_block_nums = (
                self.block_manager.get_common_computed_block_ids(
                    running_seqs))

            is_prompt = seq_group.is_prefill()
            do_sample = True
            if is_prompt:
                seqs = seq_group.get_seqs()
                # Prefill has only 1 sequence.
                assert len(seqs) == 1
                # If this chunk does not reach the end of the sequence, the
                # prefill is chunked and more chunks follow, so there is
                # nothing to sample in this step.
                # NOTE: We use get_len instead of get_prompt_len because when
                # a sequence is preempted, prefill includes previous generated
                # output tokens.
                if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
                        seqs[0].data.get_len()):
                    do_sample = False

            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=is_prompt,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                do_sample=do_sample,
                pooling_params=seq_group.pooling_params,
                token_chunk_size=token_chunk_size,
                lora_request=seq_group.lora_request,
                computed_block_nums=common_computed_block_nums,
                # `multi_modal_data` will only be present for the 1st comm
                # between engine and worker.
                # the subsequent comms can still use delta, but
                # `multi_modal_data` will be None.
                multi_modal_data=seq_group.multi_modal_data
                if scheduler_outputs.num_prefill_groups > 0 else None,
                prompt_adapter_request=seq_group.prompt_adapter_request,
            )
            seq_group_metadata_list.append(seq_group_metadata)

        # Now that the batch has been created, we can assume all blocks in the
        # batch will have been computed before the next scheduling invocation.
        # This is because the engine assumes that a failure in model execution
        # will crash the vLLM instance / will not retry.
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            self.block_manager.mark_blocks_as_computed(
                scheduled_seq_group.seq_group)

        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Have the block manager fork ``parent_seq`` into ``child_seq``."""
        manager = self.block_manager
        manager.fork(parent_seq, child_seq)

    def free_seq(self, seq: Sequence) -> None:
        """Release ``seq`` from the block manager's block table."""
        manager = self.block_manager
        manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        """Drop finished sequence groups from the running queue.

        The request id of every finished group is appended to
        ``self._finished_requests_ids``, which is used to update the Mamba
        cache in the next step. Unfinished groups stay in ``self.running`` in
        their original order; ``self.running`` is replaced by a new deque.
        """
        still_running: Deque[SequenceGroup] = deque()
        for group in self.running:
            if not group.is_finished():
                still_running.append(group)
                continue
            self._finished_requests_ids.append(group.request_id)
        self.running = still_running

    def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for ``seq_group`` and mark its WAITING seqs RUNNING.
        """
        self.block_manager.allocate(seq_group)
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        for waiting_seq in waiting_seqs:
            waiting_seq.status = SequenceStatus.RUNNING

    def _append_slots(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        """Append decode slots to every RUNNING sequence in ``seq_group``.

        Args:
            seq_group (SequenceGroup): The group whose running sequences get
                new slots.
            blocks_to_copy (List[Tuple[int, int]]): Output list of
                (source block index, destination block index) pairs. It is
                extended in place with the pairs the block manager returns
                for the appended slots.
        """
        lookahead = self._get_num_lookahead_slots(is_prefill=False)
        manager = self.block_manager
        for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            blocks_to_copy.extend(manager.append_slots(running_seq, lookahead))

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> PreemptionMode:
        """Preempt ``seq_group`` to free KV cache space.

        Args:
            seq_group: The sequence group to preempt.
            blocks_to_swap_out: Extended in place with the block mapping when
                the group is preempted by swapping.
            preemption_mode: The mode to apply. If None (the default), the
                mode is derived from ``self.user_specified_preemption_mode``,
                or, when that is unset, from the number of sequences the
                group may run.

        Returns:
            The preemption mode that was applied.

        Raises:
            AssertionError: If the mode is not a known PreemptionMode.
        """
        # Only derive a mode when the caller did not pass one explicitly;
        # previously an explicit argument was silently overwritten.
        if preemption_mode is None:
            # We use recomputation by default since it incurs lower overhead
            # than swapping. However, when the sequence group has multiple
            # sequences (e.g., beam search), recomputation is not currently
            # supported. In such a case, we use swapping instead.
            # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
            # As swapped sequences are prioritized over waiting sequences,
            # sequence groups with multiple sequences are implicitly
            # prioritized over sequence groups with a single sequence.
            # TODO(woosuk): Support recomputation for sequence groups with
            # multiple sequences. This may require a more sophisticated CUDA
            # kernel.
            if self.user_specified_preemption_mode is None:
                if seq_group.get_max_num_running_seqs() == 1:
                    preemption_mode = PreemptionMode.RECOMPUTE
                else:
                    preemption_mode = PreemptionMode.SWAP
            elif self.user_specified_preemption_mode == "swap":
                preemption_mode = PreemptionMode.SWAP
            else:
                # NOTE(review): forcing RECOMPUTE on a group with more than
                # one running sequence trips the single-sequence assertion in
                # _preempt_by_recompute.
                preemption_mode = PreemptionMode.RECOMPUTE

        # Warn on the first preemption and then every 50th one.
        if self.num_cumulative_preemption % 50 == 0:
            logger.warning(
                "Sequence group %s is preempted by %s mode because there is "
                "not enough KV cache space. This can affect the end-to-end "
                "performance. Increase gpu_memory_utilization or "
                "tensor_parallel_size to provide more KV cache memory. "
                "total_num_cumulative_preemption=%d", seq_group.request_id,
                preemption_mode, self.num_cumulative_preemption + 1)
        self.num_cumulative_preemption += 1

        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")
        return preemption_mode

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Preempt a single-sequence group by discarding its blocks.

        The running sequence is moved back to WAITING, its blocks are freed,
        and its state is reset so it is recomputed when it is resumed.
        """
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Recomputation supports exactly one running sequence per group.
        assert len(running_seqs) == 1
        for victim in running_seqs:
            victim.status = SequenceStatus.WAITING
            self.free_seq(victim)
            victim.reset_state_for_recompute()

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Preempt ``seq_group`` by swapping its blocks out to CPU memory.

        ``blocks_to_swap_out`` is extended in place with the swap mapping.
        """
        swap_out = self._swap_out
        swap_out(seq_group, blocks_to_swap_out)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: List[Tuple[int, int]],
    ) -> None:
        """Swap ``seq_group`` back in and mark its SWAPPED seqs RUNNING.

        ``blocks_to_swap_in`` is extended in place with the block mapping
        returned by the block manager.
        """
        blocks_to_swap_in.extend(self.block_manager.swap_in(seq_group))
        for swapped_seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            swapped_seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Swap ``seq_group`` out and mark its RUNNING seqs SWAPPED.

        ``blocks_to_swap_out`` is extended in place with the block mapping
        returned by the block manager.

        Raises:
            RuntimeError: If the block manager reports the group cannot be
                swapped out (not enough CPU swap space).
        """
        manager = self.block_manager
        if not manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        blocks_to_swap_out.extend(manager.swap_out(seq_group))
        for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            running_seq.status = SequenceStatus.SWAPPED

    def _passed_delay(self, now: float) -> bool:
        if self.prev_prompt:
            self.last_prompt_latency = now - self.prev_time
        self.prev_time, self.prev_prompt = now, False
        # Delay scheduling prompts to let waiting queue fill up
        if self.scheduler_config.delay_factor > 0 and self.waiting:
            earliest_arrival_time = min(
                [e.metrics.arrival_time for e in self.waiting])
            passed_delay = (
                (now - earliest_arrival_time) >
                (self.scheduler_config.delay_factor * self.last_prompt_latency)
                or not self.running)
        else:
            passed_delay = True
        return passed_delay
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187

    def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
        """The number of slots to allocate per sequence per step, beyond known
        token ids. Speculative decoding uses these slots to store KV activations
        of tokens which may or may not be accepted.

        Speculative decoding does not yet support prefill, so we do not perform
        lookahead allocation for prefill.
        """
        if is_prefill:
            return 0

        return self.scheduler_config.num_lookahead_slots
1188
1189
1190

    def _get_num_new_tokens(self, seq_group: SequenceGroup,
                            status: SequenceStatus, enable_chunking: bool,
                            budget: SchedulingBudget) -> int:
        """Count the new tokens to compute for ``seq_group``'s seqs in
        ``status``.

        If ``enable_chunking`` is True and the group has exactly one sequence
        in ``status``, the count is capped at the budget's remaining token
        budget. A group with several sequences (e.g., beam search) is in the
        decoding phase and is never chunked.

        Returns 0 if the new token cannot be computed due to token budget.
        """
        seqs = seq_group.get_seqs(status=status)
        num_new_tokens = sum(seq.get_num_new_tokens() for seq in seqs)
        assert num_new_tokens > 0

        # Chunk only a single running sequence that cannot fit in.
        if not enable_chunking or len(seqs) != 1:
            return num_new_tokens
        return min(num_new_tokens, budget.remaining_token_budget())