scheduler.py 53.2 KB
Newer Older
1
import enum
2
3
import os
import random
4
import time
5
from collections import deque
6
from dataclasses import dataclass, field
7
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
10
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
11
from vllm.core.policy import Policy, PolicyFactory
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
15
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
16
                           SequenceGroupMetadata, SequenceStatus)
Woosuk Kwon's avatar
Woosuk Kwon committed
17

Woosuk Kwon's avatar
Woosuk Kwon committed
18
logger = init_logger(__name__)
19

20
21
22
23
24
25
26
# Test-only. If configured, decode is preempted with
# ARTIFICIAL_PREEMPTION_PROB% probability.
ENABLE_ARTIFICIAL_PREEMPT = bool(
    os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False))  # noqa
ARTIFICIAL_PREEMPTION_PROB = 0.5
ARTIFICIAL_PREEMPTION_MAX_CNT = 500

Woosuk Kwon's avatar
Woosuk Kwon committed
27

28
29
30
31
32
33
34
35
36
37
38
39
40
class PreemptionMode(enum.Enum):
    """How the KV-cache blocks of a preempted sequence group are handled.

    SWAP:
        The group's blocks are swapped out to CPU memory and swapped back
        in when the group resumes.
    RECOMPUTE:
        The group's blocks are discarded; on resumption the group is
        treated as a new prompt and its blocks are recomputed.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


41
42
@dataclass
class SchedulingBudget:
    """Token and sequence slots still available for the batch being built.

    Every add/subtract call is keyed by request id and counts a request at
    most once: a repeated ``add_*`` for an id that is already counted is a
    no-op, and ``subtract_*`` only has an effect for ids that were counted.

    TODO(sang): The request-id bookkeeping exists because the default
    scheduling path updates the RUNNING ``num_seqs`` ahead of time, so the
    same request could otherwise be counted more than once while RUNNING
    requests are scheduled. That cannot happen with chunked-prefill-only
    scheduling, so this feature can be removed from the API once chunked
    prefill is enabled by default.
    """
    # Upper bound on the total number of batched tokens.
    token_budget: int
    # Upper bound on the total number of sequences.
    max_num_seqs: int
    # Request ids already counted by each counter.
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    # Running totals.
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
        """Return whether the new tokens AND the new seqs both still fit."""
        assert num_new_tokens != 0
        assert num_new_seqs != 0
        tokens_fit = (self.num_batched_tokens + num_new_tokens
                      <= self.token_budget)
        seqs_fit = self.num_curr_seqs + num_new_seqs <= self.max_num_seqs
        return tokens_fit and seqs_fit

    def remaining_token_budget(self):
        """Tokens that can still be added before reaching ``token_budget``."""
        return self.token_budget - self.num_batched_tokens

    def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
        """Count ``num_batched_tokens`` for ``req_id`` unless already counted."""
        counted = self._request_ids_num_batched_tokens
        if req_id not in counted:
            counted.add(req_id)
            self._num_batched_tokens += num_batched_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int):
        """Un-count ``num_batched_tokens`` if ``req_id`` was counted."""
        counted = self._request_ids_num_batched_tokens
        if req_id in counted:
            counted.remove(req_id)
            self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Count ``num_curr_seqs`` for ``req_id`` unless already counted."""
        counted = self._request_ids_num_curr_seqs
        if req_id not in counted:
            counted.add(req_id)
            self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Un-count ``num_curr_seqs`` if ``req_id`` was counted."""
        counted = self._request_ids_num_curr_seqs
        if req_id in counted:
            counted.remove(req_id)
            self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self):
        """Tokens currently counted against the budget."""
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self):
        """Sequences currently counted against the budget."""
        return self._num_curr_seqs

101

102
103
104
105
106
107
108
109
110
111
@dataclass
class ScheduledSequenceGroup:
    """A sequence group selected for the next model iteration.

    Attributes:
        seq_group: The scheduled sequence group.
        token_chunk_size: Number of tokens to process for the group in the
            next iteration. It is 1 for decoding and equals the number of
            prompt tokens for a prefill, except that a chunked prefill may
            use a smaller value.
    """
    seq_group: SequenceGroup
    token_chunk_size: int


112
@dataclass
class SchedulerOutputs:
    """The scheduling decision produced by the scheduler for one step."""
    # Sequence groups scheduled for this step. When any of them uses a
    # LoRA, they are re-sorted by (lora_int_id, request_id) on creation.
    scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
    # How many of the scheduled groups are prefills.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # Blocks to swap in, as (CPU block, GPU block) pairs.
    blocks_to_swap_in: List[Tuple[int, int]]
    # Blocks to swap out, as (GPU block, CPU block) pairs.
    blocks_to_swap_out: List[Tuple[int, int]]
    # Blocks to copy, as (source block, destination block) pairs.
    blocks_to_copy: List[Tuple[int, int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # The number of requests in the running queue.
    running_queue_size: int
    # NOTE(review): presumably the number of preempted groups -- confirm
    # against the code that builds this object.
    preempted: int

    def __post_init__(self):
        # A step never swaps blocks in and out at the same time.
        assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)

        self.num_loras: int = len(self.lora_requests)
        if self.num_loras > 0:
            # Keep groups that share a LoRA next to each other.
            self._sort_by_lora_ids()

        self.num_prompt_adapters: int = len(self.prompt_adapter_requests)

    def is_empty(self) -> bool:
        """Return True when there is nothing to run, swap, or copy.

        Ignored sequence groups are deliberately not taken into account.
        """
        has_work = (self.scheduled_seq_groups or self.blocks_to_swap_in
                    or self.blocks_to_swap_out or self.blocks_to_copy)
        return not has_work

    def _sort_by_lora_ids(self):
        """Order scheduled groups by LoRA id, breaking ties by request id."""

        def _order(scheduled):
            group = scheduled.seq_group
            return group.lora_int_id, group.request_id

        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
                                           key=_order)

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """Distinct LoRA requests used by the scheduled groups."""
        candidates = (scheduled.seq_group.lora_request
                      for scheduled in self.scheduled_seq_groups)
        return {request for request in candidates if request is not None}

    @property
    def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
        """Distinct prompt-adapter requests used by the scheduled groups."""
        candidates = (scheduled.seq_group.prompt_adapter_request
                      for scheduled in self.scheduled_seq_groups)
        return {request for request in candidates if request is not None}

171

172
@dataclass
class SchedulerRunningOutputs:
    """The requests that are scheduled from a running queue.

    Could contain prefills (prefills that were chunked) or decodes. If there
    is not enough memory, requests can be preempted (for recompute) or
    swapped out.
    """
    # Selected sequences that are running and in a decoding phase. Entries
    # are ScheduledSequenceGroup wrappers (token_chunk_size == 1), not bare
    # SequenceGroups.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are running and in a prefill phase, i.e. the
    # prefill has been chunked. Entries are ScheduledSequenceGroup wrappers.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # Sequence groups preempted with PreemptionMode.RECOMPUTE.
    preempted: List[SequenceGroup]
    # Sequence groups that are swapped out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an instance with all lists empty and no lookahead slots."""
        # Build via ``cls`` so a subclass gets an instance of its own type.
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
        )


@dataclass
class SchedulerSwappedInOutputs:
    """The requests that are scheduled from a swap queue.

    Could contain prefills (prefills that were chunked) or decodes.
    """
    # Selected sequences that are going to be swapped in and are in a
    # decoding phase. Entries are ScheduledSequenceGroup wrappers
    # (token_chunk_size == 1), not bare SequenceGroups.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are going to be swapped in and are in a
    # prefill phase, i.e. the prefill has been chunked. Entries are
    # ScheduledSequenceGroup wrappers.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # Sequence groups that can never be swapped in (AllocStatus.NEVER); their
    # sequences are marked FINISHED_IGNORED.
    infeasible_seq_groups: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance with all lists empty and no lookahead slots."""
        # Build via ``cls`` so a subclass gets an instance of its own type.
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            blocks_to_swap_in=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            infeasible_seq_groups=[],
        )


@dataclass
class SchedulerPrefillOutputs:
    """The requests that are scheduled from a waiting queue.

    Could contain fresh prefill requests or preempted requests that need to
    be recomputed from scratch.
    """
    # Selected sequences for prefill.
    # NOTE(review): the running/swapped outputs hold ScheduledSequenceGroup
    # entries; check whether _schedule_prefills appends those here too and,
    # if so, tighten this annotation.
    seq_groups: List[SequenceGroup]
    # Sequence groups that are ignored (e.g. the prompt exceeds the limit).
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an instance with all lists empty and no lookahead slots."""
        # Build via ``cls`` so a subclass gets an instance of its own type.
        return cls(
            seq_groups=[],
            ignored_seq_groups=[],
            num_lookahead_slots=0,
        )


Woosuk Kwon's avatar
Woosuk Kwon committed
263
264
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
265
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        pipeline_parallel_size: int = 1,
    ) -> None:
        """Set up the queues, the block space manager and bookkeeping state.

        Args:
            scheduler_config: Scheduler settings. The block-manager version
                is chosen from ``embedding_mode`` / ``use_v2_block_manager``,
                and ``preemption_mode`` is recorded.
            cache_config: KV-cache settings used to build the block manager.
            lora_config: LoRA settings, or None when LoRA is not used.
            pipeline_parallel_size: The configured GPU and CPU block counts
                are floor-divided by this value (when they are set).
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        # Embedding mode takes precedence over the v2 block-manager flag.
        if self.scheduler_config.embedding_mode:
            version = "embedding"
        elif self.scheduler_config.use_v2_block_manager:
            version = "v2"
        else:
            version = "v1"
        block_manager_cls = BlockSpaceManager.get_block_space_manager_class(
            version)

        def _per_stage(num_blocks):
            # Unset (None) or zero counts are passed through unchanged.
            return (num_blocks // pipeline_parallel_size
                    if num_blocks else num_blocks)

        # Create the block space manager.
        self.block_manager = block_manager_cls(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=_per_stage(cache_config.num_gpu_blocks),
            num_cpu_blocks=_per_stage(cache_config.num_cpu_blocks),
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)

        # WAITING groups: new prefills or preempted requests.
        self.waiting: Deque[SequenceGroup] = deque()
        # RUNNING groups: decode requests.
        self.running: Deque[SequenceGroup] = deque()
        # SWAPPED groups: decode requests that are swapped out.
        self.swapped: Deque[SequenceGroup] = deque()
        # Ids of requests finished since the last step. Any state tied to
        # them can and must be released after the current step; this is used
        # to evict finished requests from the Mamba cache.
        self._finished_requests_ids: List[str] = list()

        # Time at the previous scheduling step.
        self.prev_time = 0.0
        # Whether a prompt was scheduled at the previous step.
        self.prev_prompt = False
        # Latency of the last prompt step.
        self.last_prompt_latency = 0.0
        # User-specified preemption mode (RECOMPUTE or SWAP).
        self.user_specified_preemption_mode = scheduler_config.preemption_mode

        # Test-only: state used to inject artificial preemption.
        self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
        if self.enable_artificial_preemption:
            self.artificial_preempt_cnt = ARTIFICIAL_PREEMPTION_MAX_CNT
        else:
            self.artificial_preempt_cnt = 0
        self.num_cumulative_preemption: int = 0
334

335
336
337
338
    @property
    def lora_enabled(self) -> bool:
        return bool(self.lora_config)

339
340
341
342
343
    @property
    def num_decoding_tokens_per_seq(self) -> int:
        """The number of new tokens."""
        return 1

344
    def add_seq_group(self, seq_group: SequenceGroup) -> None:
345
        # Add sequence groups to the waiting queue.
346
        self.waiting.append(seq_group)
Woosuk Kwon's avatar
Woosuk Kwon committed
347

Antoni Baum's avatar
Antoni Baum committed
348
    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
349
350
351
352
353
354
355
356
357
358
359
360
        """Aborts a sequence group with the given ID.

        Check if the sequence group with the given ID
            is present in any of the state queue.
        If present, remove the sequence group from the state queue.
            Also, if any of the sequences in the sequence group is not finished,
                free the sequence with status `FINISHED_ABORTED`.
        Otherwise, do nothing.

        Args:
            request_id: The ID(s) of the sequence group to abort.
        """
Antoni Baum's avatar
Antoni Baum committed
361
362
363
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
364
        for state_queue in [self.waiting, self.running, self.swapped]:
ljss's avatar
ljss committed
365
            aborted_groups: List[SequenceGroup] = []
366
367
368
            for seq_group in state_queue:
                if not request_ids:
                    # Using 'break' here may add two extra iterations,
369
                    # but is acceptable to reduce complexity.
370
                    break
Antoni Baum's avatar
Antoni Baum committed
371
                if seq_group.request_id in request_ids:
372
373
                    # Appending aborted group into pending list.
                    aborted_groups.append(seq_group)
Antoni Baum's avatar
Antoni Baum committed
374
                    request_ids.remove(seq_group.request_id)
375
376
377
            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
                state_queue.remove(aborted_group)
378
                # Remove the aborted request from the Mamba cache.
379
                self._finished_requests_ids.append(aborted_group.request_id)
ljss's avatar
ljss committed
380
                for seq in aborted_group.get_seqs():
381
382
383
384
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)
385

386
    def has_unfinished_seqs(self) -> bool:
387
388
        return len(self.waiting) != 0 or len(self.running) != 0 or len(
            self.swapped) != 0
389

390
391
392
    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

Mor Zusman's avatar
Mor Zusman committed
393
394
395
396
397
398
    def get_and_reset_finished_requests_ids(self) -> List[str]:
        """Flushes the list of request ids of previously finished seq_groups."""
        finished_requests_ids = self._finished_requests_ids
        self._finished_requests_ids = list()
        return finished_requests_ids

399
    def _schedule_running(
        self,
        running_queue: deque,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        policy: Policy,
        enable_chunking: bool = False,
    ) -> Tuple[deque, SchedulerRunningOutputs]:
        """Schedule the sequence groups that are already running.

        The running queue holds decodes and chunked prefills. Groups are
        visited in the order returned by ``policy.sort_by_priority``. When a
        group cannot get new slots, groups are preempted from the tail
        (lowest priority) of the queue; if none remain, the group itself is
        preempted.

        Args:
            running_queue: Running requests (decodes / chunked prefills).
                Items are popped from the deque returned by
                ``policy.sort_by_priority`` (assumed to be a new deque --
                TODO confirm), not from this argument directly.
            budget: The scheduling budget, updated in place. Tokens (and,
                with chunking, seqs) are added for every scheduled group. A
                group's tokens and seqs are subtracted while it waits for
                free slots.
            curr_loras: Ids of LoRAs in the current batch, updated in place
                (added for scheduled groups, removed for groups waiting for
                slots). None disables this bookkeeping.
            policy: The policy used to order ``running_queue``.
            enable_chunking: If True, a group may be scheduled with only a
                chunk of its tokens when ``budget.num_batched_tokens`` lacks
                room for all of them.

        Returns:
            A tuple of the remaining running queue (expected to be empty)
            and the SchedulerRunningOutputs.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_out: List[Tuple[int, int]] = []
        blocks_to_copy: List[Tuple[int, int]] = []

        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []
        preempted: List[SequenceGroup] = []
        swapped_out: List[SequenceGroup] = []

        def _preempt_and_record(group) -> None:
            # Recompute-preempted and swapped-out groups are reported in
            # separate lists.
            mode = self._preempt(group, blocks_to_swap_out)
            if mode == PreemptionMode.RECOMPUTE:
                preempted.append(group)
            else:
                swapped_out.append(group)

        # NOTE(woosuk): Preemption happens only when there is no available
        # slot to keep all the sequence groups in the RUNNING state. The
        # policy order decides which groups get preempted.
        running_queue = policy.sort_by_priority(time.time(), running_queue)
        while running_queue:
            seq_group = running_queue[0]
            num_running_tokens = self._get_num_new_tokens(
                seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
            if num_running_tokens == 0:
                break
            running_queue.popleft()

            got_slots = True
            while not self._can_append_slots(seq_group):
                budget.subtract_num_batched_tokens(seq_group.request_id,
                                                   num_running_tokens)
                budget.subtract_num_seqs(seq_group.request_id,
                                         seq_group.get_max_num_running_seqs())
                if (curr_loras is not None and seq_group.lora_int_id > 0
                        and seq_group.lora_int_id in curr_loras):
                    curr_loras.remove(seq_group.lora_int_id)

                if running_queue:
                    # Free slots by preempting the lowest-priority group.
                    _preempt_and_record(running_queue.pop())
                else:
                    # Nothing else can be preempted; preempt this group.
                    _preempt_and_record(seq_group)
                    got_slots = False
                    break

            if not got_slots:
                continue

            self._append_slots(seq_group, blocks_to_copy)
            if seq_group.is_prefill():
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(
                        seq_group=seq_group,
                        token_chunk_size=num_running_tokens))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group=seq_group,
                                           token_chunk_size=1))
            budget.add_num_batched_tokens(seq_group.request_id,
                                          num_running_tokens)
            # OPTIMIZATION: get_max_num_running_seqs is expensive. In the
            # default (non-chunked) path num_seqs are updated before this
            # method runs, so only the chunked path updates them here.
            if enable_chunking:
                budget.add_num_seqs(seq_group.request_id,
                                    seq_group.get_max_num_running_seqs())
            if curr_loras is not None and seq_group.lora_int_id > 0:
                curr_loras.add(seq_group.lora_int_id)

        return running_queue, SchedulerRunningOutputs(
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
            preempted=preempted,
            swapped_out=swapped_out,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False))
515

516
517
518
519
520
521
    def _schedule_swapped(
        self,
        swapped_queue: deque,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        policy: Policy,
        enable_chunking: bool = False,
    ) -> Tuple[deque, SchedulerSwappedInOutputs]:
        """Schedule sequence groups that are swapped out.

        Swapped requests are taken in policy order as long as they fit
        ``budget`` and the number of LoRAs in ``curr_loras`` stays within
        ``lora_config.max_loras``. Scanning stops at the first group that
        cannot be swapped in yet or does not fit the budget.

        Args:
            swapped_queue: Swapped-out requests. Items are popped from the
                deque returned by ``policy.sort_by_priority`` (assumed to be
                a new deque -- TODO confirm), not from this argument
                directly.
            budget: The scheduling budget, updated in place for every group
                that is swapped in.
            curr_loras: Ids of LoRAs in the current batch, updated in place
                for every group that is swapped in.
            policy: The policy used to order ``swapped_queue``.
            enable_chunking: If True, a group may be scheduled with only a
                chunk of its tokens when ``budget.num_batched_tokens`` lacks
                room for all of them.

        Returns:
            A tuple of the remaining swapped queue after scheduling and the
            SchedulerSwappedInOutputs.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: List[Tuple[int, int]] = []
        blocks_to_copy: List[Tuple[int, int]] = []
        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []
        infeasible_seq_groups: List[SequenceGroup] = []
        # Groups skipped only because no LoRA slot is free; they are put
        # back at the front of the queue afterwards.
        leftover_swapped: Deque[SequenceGroup] = deque()

        swapped_queue = policy.sort_by_priority(time.time(), swapped_queue)
        while swapped_queue:
            seq_group = swapped_queue[0]

            alloc_status = self.block_manager.can_swap_in(
                seq_group,
                self._get_num_lookahead_slots(seq_group.is_prefill()))
            if alloc_status == AllocStatus.LATER:
                # Cannot be swapped in right now; stop.
                break
            if alloc_status == AllocStatus.NEVER:
                logger.warning(
                    "Failing the request %s because there's not enough kv "
                    "cache blocks to run the entire sequence.",
                    seq_group.request_id)
                for seq in seq_group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible_seq_groups.append(seq_group)
                swapped_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                no_lora_slot = (
                    lora_int_id > 0 and lora_int_id not in curr_loras
                    and len(curr_loras) >= self.lora_config.max_loras)
                if no_lora_slot:
                    # No room for another LoRA; leave this request for later.
                    leftover_swapped.appendleft(seq_group)
                    swapped_queue.popleft()
                    continue

            # The total number of RUNNING sequences must stay within the
            # budget's maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.SWAPPED,
                                                      enable_chunking, budget)
            fits_budget = (num_new_tokens != 0 and budget.can_schedule(
                num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs))
            if not fits_budget:
                break

            if lora_int_id > 0 and curr_loras is not None:
                curr_loras.add(lora_int_id)
            swapped_queue.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slots(seq_group, blocks_to_copy)
            if seq_group.is_prefill():
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(seq_group,
                                           token_chunk_size=num_new_tokens))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # extendleft reverses its argument; since leftover_swapped was built
        # with appendleft, the skipped groups regain their original order.
        swapped_queue.extendleft(leftover_swapped)

        return swapped_queue, SchedulerSwappedInOutputs(
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False),
            infeasible_seq_groups=infeasible_seq_groups,
        )
629

630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
        if self.scheduler_config.chunked_prefill_enabled:
            prompt_limit = self.scheduler_config.max_model_len
        else:
            prompt_limit = min(self.scheduler_config.max_model_len,
                               self.scheduler_config.max_num_batched_tokens)

        # Model is fine tuned with long context. Return the fine tuned max_len.
        if (seq_group.lora_request
                and seq_group.lora_request.long_lora_max_len):
            assert prompt_limit <= seq_group.lora_request.long_lora_max_len
            return seq_group.lora_request.long_lora_max_len
        else:
            return prompt_limit

645
646
647
648
649
    def _schedule_prefills(
        self,
        waiting_queue: deque,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> Tuple[deque, SchedulerPrefillOutputs]:
        """Schedule sequence groups that are in prefill stage.

        Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
        as a new prefill (that starts from beginning -> most recently generated
        tokens).

        It schedules waiting requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.

        Requests that are skipped only because no LoRA slot is free are set
        aside and put back at the front of the returned queue, in their
        original order.

        Args:
            waiting_queue: The queue that contains prefill requests.
                The given arguments are NOT in-place modified.
            budget: The scheduling budget. The argument is in-place updated
                when any requests are scheduled.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are scheduled.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            A tuple of the remaining waiting_queue after scheduling and
            SchedulerPrefillOutputs.
        """
        ignored_seq_groups: List[SequenceGroup] = []
        seq_groups: List[ScheduledSequenceGroup] = []
        # We don't sort waiting queue because we assume it is sorted.
        # Copy the queue so that the input queue is not modified.
        waiting_queue = deque(waiting_queue)

        # Groups blocked only by the LoRA slot limit. appendleft() here plus
        # extendleft() at the end restores their original relative order.
        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
        while self._passed_delay(time.time()) and waiting_queue:
            seq_group = waiting_queue[0]

            waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.WAITING,
                                                      enable_chunking, budget)
            if not enable_chunking:
                num_prompt_tokens = waiting_seqs[0].get_len()
                assert num_new_tokens == num_prompt_tokens

            prompt_limit = self._get_prompt_limit(seq_group)
            if num_new_tokens > prompt_limit:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds limit of %d", num_new_tokens, prompt_limit)
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            # If the sequence group cannot be allocated, stop.
            can_allocate = self.block_manager.can_allocate(seq_group)
            if can_allocate == AllocStatus.LATER:
                break
            elif can_allocate == AllocStatus.NEVER:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds the capacity of block_manager",
                    num_new_tokens)
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and lora_int_id not in curr_loras
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_waiting_sequences.appendleft(seq_group)
                    waiting_queue.popleft()
                    continue

            num_new_seqs = seq_group.get_max_num_running_seqs()
            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
                break

            # Can schedule this request.
            if curr_loras is not None and lora_int_id > 0:
                curr_loras.add(lora_int_id)
            waiting_queue.popleft()
            self._allocate_and_set_running(seq_group)
            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=num_new_tokens))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Queue requests that couldn't be scheduled.
        waiting_queue.extendleft(leftover_waiting_sequences)
        if seq_groups:
            self.prev_prompt = True

        return waiting_queue, SchedulerPrefillOutputs(
            seq_groups=seq_groups,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))

    def _schedule_default(self) -> SchedulerOutputs:
        """Schedule queued requests with the throughput-oriented policy.

        New prefills are batched first (as many as fit), but only when no
        request is currently swapped out. Running decodes, and then swap-ins,
        are only scheduled in a step where no prefill was scheduled. Under
        GPU memory pressure, running decode requests may be swapped out or
        preempted.
        """
        config = self.scheduler_config
        budget = SchedulingBudget(
            token_budget=config.max_num_batched_tokens,
            max_num_seqs=config.max_num_seqs,
        )
        # Account for already-running sequences up front so that prefill
        # scheduling cannot push the batch past max_num_seqs.
        for running_group in self.running:
            budget.add_num_seqs(running_group.request_id,
                                running_group.get_max_num_running_seqs())
        if self.lora_enabled:
            curr_loras = {
                running_group.lora_int_id
                for running_group in self.running
                if running_group.lora_int_id > 0
            }
        else:
            curr_loras = None

        waiting_left = self.waiting
        prefill_out = SchedulerPrefillOutputs.create_empty()
        running_left = self.running
        running_out = SchedulerRunningOutputs.create_empty()
        swapped_left = self.swapped
        swap_in_out = SchedulerSwappedInOutputs.create_empty()

        # Swapped requests take priority: admit new prefills only when
        # nothing is swapped out.
        if not self.swapped:
            waiting_left, prefill_out = self._schedule_prefills(
                self.waiting, budget, curr_loras, enable_chunking=False)

        fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
        # Decodes are skipped whenever prefills were scheduled. Since chunking
        # is disabled here, self.running holds only decode requests.
        if not prefill_out.seq_groups:
            running_left, running_out = self._schedule_running(
                self.running,
                budget,
                curr_loras,
                fcfs_policy,
                enable_chunking=False)

            # Any preemption means there is no room for swapped-in requests.
            if not running_out.preempted and not running_out.swapped_out:
                swapped_left, swap_in_out = self._schedule_swapped(
                    self.swapped, budget, curr_loras, fcfs_policy)

        assert budget.num_batched_tokens <= config.max_num_batched_tokens
        assert budget.num_curr_seqs <= config.max_num_seqs

        # Preempted (recompute) groups go back to the front of the waiting
        # queue.
        self.waiting = waiting_left
        self.waiting.extendleft(running_out.preempted)

        # Running queue: leftovers, new prefills, decodes, swapped-in decodes.
        self.running = running_left
        self.running.extend(sg.seq_group for sg in prefill_out.seq_groups)
        self.running.extend(
            sg.seq_group for sg in running_out.decode_seq_groups)
        self.running.extend(
            sg.seq_group for sg in swap_in_out.decode_seq_groups)

        self.swapped = swapped_left
        self.swapped.extend(running_out.swapped_out)

        num_preempted = (len(running_out.preempted) +
                         len(running_out.swapped_out))

        # This policy never chunks prefills, so neither the running nor the
        # swapped path may produce prefill groups.
        assert not running_out.prefill_seq_groups
        assert not swap_in_out.prefill_seq_groups

        scheduled = (prefill_out.seq_groups + running_out.decode_seq_groups +
                     swap_in_out.decode_seq_groups)
        return SchedulerOutputs(
            scheduled_seq_groups=scheduled,
            num_prefill_groups=len(prefill_out.seq_groups),
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swap_in_out.blocks_to_swap_in,
            blocks_to_swap_out=running_out.blocks_to_swap_out,
            blocks_to_copy=(running_out.blocks_to_copy +
                            swap_in_out.blocks_to_copy),
            ignored_seq_groups=(prefill_out.ignored_seq_groups +
                                swap_in_out.infeasible_seq_groups),
            num_lookahead_slots=running_out.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=num_preempted,
        )

    def _schedule_chunked_prefill(self):
        """Schedule queued requests with chunked prefill enabled.

        Chunked prefill allows prefill requests to be split into chunks and
        batched together with decode requests. This policy:

        1. schedules as many running requests (decodes and unfinished chunked
           prefills) as possible,
        2. schedules swapped requests, unless step 1 had to preempt,
        3. schedules new prefill requests with the remaining budget.

        The policy can sustain high GPU utilization because prefill and
        decode requests share a batch, while it improves inter-token latency
        because decode requests don't need to be blocked by prefill requests.

        Returns:
            SchedulerOutputs for this step.
        """
        budget = SchedulingBudget(
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
        curr_loras: Set[int] = set()

        # Running requests (decodes and in-flight chunked prefills) are always
        # scheduled first, in FCFS order.
        fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
        remaining_running, running_scheduled = self._schedule_running(
            self.running,
            budget,
            curr_loras,
            fcfs_policy,
            enable_chunking=True)
        num_preempted = (len(running_scheduled.preempted) +
                         len(running_scheduled.swapped_out))

        # Schedule swapped out requests.
        # If preemption happens, it means we don't have space for swap-in.
        remaining_swapped, swapped_in = (
            self.swapped, SchedulerSwappedInOutputs.create_empty())
        if num_preempted == 0:
            remaining_swapped, swapped_in = self._schedule_swapped(
                self.swapped, budget, curr_loras, fcfs_policy)

        # Schedule new prefills with whatever budget is left.
        remaining_waiting, prefills = self._schedule_prefills(
            self.waiting, budget, curr_loras, enable_chunking=True)

        assert (budget.num_batched_tokens <=
                self.scheduler_config.max_num_batched_tokens)
        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

        # Update waiting requests.
        self.waiting = remaining_waiting
        self.waiting.extendleft(running_scheduled.preempted)
        # Update new running requests.
        self.running = remaining_running
        self.running.extend([s.seq_group for s in prefills.seq_groups])
        self.running.extend(
            [s.seq_group for s in running_scheduled.decode_seq_groups])
        self.running.extend(
            [s.seq_group for s in running_scheduled.prefill_seq_groups])
        self.running.extend(
            [s.seq_group for s in swapped_in.decode_seq_groups])
        self.running.extend(
            [s.seq_group for s in swapped_in.prefill_seq_groups])
        # Update swapped requests.
        self.swapped = remaining_swapped
        self.swapped.extend(running_scheduled.swapped_out)
        return SchedulerOutputs(
            scheduled_seq_groups=(prefills.seq_groups +
                                  running_scheduled.prefill_seq_groups +
                                  swapped_in.prefill_seq_groups +
                                  running_scheduled.decode_seq_groups +
                                  swapped_in.decode_seq_groups),
            num_prefill_groups=(len(prefills.seq_groups) +
                                len(swapped_in.prefill_seq_groups) +
                                len(running_scheduled.prefill_seq_groups)),
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=running_scheduled.blocks_to_copy +
            swapped_in.blocks_to_copy,
            ignored_seq_groups=prefills.ignored_seq_groups +
            swapped_in.infeasible_seq_groups,
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=num_preempted,
        )

    def _schedule(self) -> SchedulerOutputs:
        """Run one scheduling step using the configured policy.

        Dispatches to the chunked-prefill policy when
        `scheduler_config.chunked_prefill_enabled` is set, otherwise to the
        default policy.
        """
        use_chunked_prefill = self.scheduler_config.chunked_prefill_enabled
        return (self._schedule_chunked_prefill()
                if use_chunked_prefill else self._schedule_default())

    def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
        """Return whether the KV cache has room to keep decoding `seq_group`.

        When artificial preemption is enabled (testing only), this randomly
        reports "no room" to force a preemption, at most
        `artificial_preempt_cnt` more times.
        """
        # The operand order matters: random.uniform is only consumed when
        # artificial preemption is enabled.
        force_preemption = (self.enable_artificial_preemption
                            and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
                            and self.artificial_preempt_cnt > 0)
        if force_preemption:
            self.artificial_preempt_cnt -= 1
            return False

        # Slots are only appended while decoding, never during prefill.
        lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
        return self.block_manager.can_append_slots(
            seq_group=seq_group,
            num_lookahead_slots=lookahead_slots,
        )

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Run one scheduling step and build per-group worker metadata.

        Calls `_schedule()`, which changes the scheduler's internal queues
        (self.running, self.swapped and self.waiting). Then it builds one
        SequenceGroupMetadata per scheduled sequence group. Finally it marks
        the blocks of every scheduled group as computed.

        Returns:
            A tuple of (metadata list in scheduled order, SchedulerOutputs).
        """
        scheduler_outputs = self._schedule()
        now = time.time()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            token_chunk_size = scheduled_seq_group.token_chunk_size
            seq_group.maybe_set_first_scheduled_time(now)

            running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            # seq_id -> SequenceData
            seq_data: Dict[int, SequenceData] = {}
            # seq_id -> physical block numbers
            block_tables: Dict[int, List[int]] = {}
            for seq in running_seqs:
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                self.block_manager.access_all_blocks_in_seq(seq, now)

            common_computed_block_nums = (
                self.block_manager.get_common_computed_block_ids(running_seqs))

            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
            is_prompt = seq_group.is_prefill()
            do_sample = True
            if is_prompt:
                seqs = seq_group.get_seqs()
                # Prefill has only 1 sequence.
                assert len(seqs) == 1
                # If this chunk does not reach the end of the sequence, the
                # prefill is chunked and there is nothing to sample yet.
                # NOTE: We use get_len instead of get_prompt_len because when
                # a sequence is preempted, prefill includes previous generated
                # output tokens.
                if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
                        seqs[0].data.get_len()):
                    do_sample = False

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=is_prompt,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                do_sample=do_sample,
                pooling_params=seq_group.pooling_params,
                token_chunk_size=token_chunk_size,
                lora_request=seq_group.lora_request,
                computed_block_nums=common_computed_block_nums,
                # `multi_modal_data` will only be present for the 1st comm
                # between engine and worker.
                # the subsequent comms can still use delta, but
                # `multi_modal_data` will be None.
                multi_modal_data=seq_group.multi_modal_data
                if scheduler_outputs.num_prefill_groups > 0 else None,
                prompt_adapter_request=seq_group.prompt_adapter_request,
            )
            seq_group_metadata_list.append(seq_group_metadata)

        # Now that the batch has been created, we can assume all blocks in the
        # batch will have been computed before the next scheduling invocation.
        # This is because the engine assumes that a failure in model execution
        # will crash the vLLM instance / will not retry.
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            self.block_manager.mark_blocks_as_computed(
                scheduled_seq_group.seq_group)

        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Delegate forking `parent_seq` into `child_seq` to the block manager.
        """
        manager = self.block_manager
        manager.fork(parent_seq, child_seq)

    def free_seq(self, seq: Sequence) -> None:
        """Release the blocks of `seq` via the block manager."""
        manager = self.block_manager
        manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        """Drop finished groups from the running queue.

        The request ids of the dropped groups are appended, in queue order,
        to `self._finished_requests_ids`; they are used to update the Mamba
        cache in the next step. `self.running` is replaced by a new deque
        that holds only the unfinished groups, in their original order.
        """
        still_running: Deque[SequenceGroup] = deque()
        for group in self.running:
            if not group.is_finished():
                still_running.append(group)
                continue
            self._finished_requests_ids.append(group.request_id)
        self.running = still_running

    def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for `seq_group` and mark its WAITING seqs RUNNING."""
        self.block_manager.allocate(seq_group)
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        for waiting_seq in waiting_seqs:
            waiting_seq.status = SequenceStatus.RUNNING

    def _append_slots(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        """Append decode slots to every RUNNING sequence in `seq_group`.

        Args:
            seq_group (SequenceGroup): The sequence group whose running
                sequences receive new slots.
            blocks_to_copy (List[Tuple[int, int]]): Output list, extended in
                place with the (source block, destination block) pairs
                returned by the block manager for each appended sequence.
        """
        lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
        for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            blocks_to_copy.extend(
                self.block_manager.append_slots(running_seq, lookahead_slots))

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> PreemptionMode:
        """Preempt `seq_group`, either by recomputation or by swapping.

        Args:
            seq_group: The sequence group to preempt.
            blocks_to_swap_out: Output list, extended in place with block
                mappings when the group is preempted by swapping.
            preemption_mode: The mode to use. When None (the default), the
                mode is derived from `user_specified_preemption_mode`. If that
                is also None, RECOMPUTE is used for single-sequence groups and
                SWAP otherwise.

        Returns:
            The preemption mode that was applied.

        Raises:
            AssertionError: If the resolved mode is not a known
                PreemptionMode.
            RuntimeError: Propagated from swapping out when CPU swap space is
                insufficient.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        # An explicitly passed mode used to be silently overwritten; it is now
        # honoured.
        if preemption_mode is None:
            if self.user_specified_preemption_mode is None:
                if seq_group.get_max_num_running_seqs() == 1:
                    preemption_mode = PreemptionMode.RECOMPUTE
                else:
                    preemption_mode = PreemptionMode.SWAP
            elif self.user_specified_preemption_mode == "swap":
                preemption_mode = PreemptionMode.SWAP
            else:
                preemption_mode = PreemptionMode.RECOMPUTE

        # Log on the 1st preemption and then every 50th to limit log volume.
        if self.num_cumulative_preemption % 50 == 0:
            logger.warning(
                "Sequence group %s is preempted by %s mode because there is "
                "not enough KV cache space. This can affect the end-to-end "
                "performance. Increase gpu_memory_utilization or "
                "tensor_parallel_size to provide more KV cache memory. "
                "total_num_cumulative_preemption=%d", seq_group.request_id,
                preemption_mode, self.num_cumulative_preemption + 1)
        self.num_cumulative_preemption += 1

        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")
        return preemption_mode

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Preempt `seq_group` by discarding its KV blocks.

        The single RUNNING sequence is moved back to WAITING, its blocks are
        freed, and its state is reset so that it is recomputed from scratch.
        """
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(running_seqs) == 1
        for victim in running_seqs:
            victim.status = SequenceStatus.WAITING
            self.free_seq(victim)
            victim.reset_state_for_recompute()

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Preempt `seq_group` by swapping its blocks out to CPU memory."""
        return self._swap_out(seq_group, blocks_to_swap_out)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: List[Tuple[int, int]],
    ) -> None:
        """Swap `seq_group` back in and mark its SWAPPED seqs RUNNING.

        `blocks_to_swap_in` is extended in place with the mapping returned by
        the block manager.
        """
        blocks_to_swap_in.extend(self.block_manager.swap_in(seq_group))
        swapped_seqs = seq_group.get_seqs(status=SequenceStatus.SWAPPED)
        for swapped_seq in swapped_seqs:
            swapped_seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Swap `seq_group` out to CPU memory and mark its RUNNING seqs SWAPPED.

        `blocks_to_swap_out` is extended in place with the mapping returned by
        the block manager.

        Raises:
            RuntimeError: If the block manager reports that the group cannot
                be swapped out (not enough CPU swap space).
        """
        manager = self.block_manager
        if not manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        blocks_to_swap_out.extend(manager.swap_out(seq_group))
        for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            running_seq.status = SequenceStatus.SWAPPED

    def _passed_delay(self, now: float) -> bool:
        if self.prev_prompt:
            self.last_prompt_latency = now - self.prev_time
        self.prev_time, self.prev_prompt = now, False
        # Delay scheduling prompts to let waiting queue fill up
        if self.scheduler_config.delay_factor > 0 and self.waiting:
            earliest_arrival_time = min(
                [e.metrics.arrival_time for e in self.waiting])
            passed_delay = (
                (now - earliest_arrival_time) >
                (self.scheduler_config.delay_factor * self.last_prompt_latency)
                or not self.running)
        else:
            passed_delay = True
        return passed_delay
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217

    def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
        """The number of slots to allocate per sequence per step, beyond known
        token ids. Speculative decoding uses these slots to store KV activations
        of tokens which may or may not be accepted.

        Speculative decoding does not yet support prefill, so we do not perform
        lookahead allocation for prefill.
        """
        if is_prefill:
            return 0

        return self.scheduler_config.num_lookahead_slots
1218
1219
1220

    def _get_num_new_tokens(self, seq_group: SequenceGroup,
                            status: SequenceStatus, enable_chunking: bool,
                            budget: SchedulingBudget) -> int:
        """Count the new tokens to compute for the `status` seqs of a group.

        The count is the sum of `get_num_new_tokens()` over the group's
        sequences in `status`. When `enable_chunking` is True and the group
        has exactly one such sequence, the count is capped at the budget's
        remaining token budget, which may yield 0. Groups with several
        sequences (e.g. beam search) are in the decode phase and are never
        chunked.
        """
        matching_seqs = seq_group.get_seqs(status=status)
        total_new_tokens = sum(
            seq.get_num_new_tokens() for seq in matching_seqs)
        assert total_new_tokens > 0

        # Only a lone sequence may be chunked to fit the remaining budget.
        if not enable_chunking or len(matching_seqs) != 1:
            return total_new_tokens
        return min(total_new_tokens, budget.remaining_token_budget())