scheduler.py 56.1 KB
Newer Older
1
import enum
2
3
import os
import random
4
import time
5
from collections import deque
6
from dataclasses import dataclass, field
7
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
10
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
15
                           SequenceGroupMetadata, SequenceStatus)
16
from vllm.utils import PyObjectCache
Woosuk Kwon's avatar
Woosuk Kwon committed
17

Woosuk Kwon's avatar
Woosuk Kwon committed
18
logger = init_logger(__name__)  # Module logger (e.g. infeasible swap-in warnings).
19

20
21
22
23
24
25
26
# Test-only. If configured, decode is preempted with probability
# ARTIFICIAL_PREEMPTION_PROB (a fraction, not a percentage).
#
# The flag is read from VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT. An unset or
# empty variable, or one of "0", "false", "no", "off" (case-insensitive,
# surrounding whitespace ignored), disables it; any other value enables it.
# A plain ``bool(os.getenv(...))`` would wrongly treat "0" as enabled.
ENABLE_ARTIFICIAL_PREEMPT = os.getenv(
    "VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT",
    "").strip().lower() not in ("", "0", "false", "no", "off")  # noqa
ARTIFICIAL_PREEMPTION_PROB = 0.5
# Initial value of a Scheduler's artificial-preemption counter when the
# feature is enabled (see Scheduler.__init__).
ARTIFICIAL_PREEMPTION_MAX_CNT = 500


class PreemptionMode(enum.Enum):
    """Strategies for preempting sequence groups.

    SWAP: the blocks of the preempted sequences are moved out to CPU memory
        and brought back in when the sequences resume.
    RECOMPUTE: the blocks of the preempted sequences are discarded; when the
        sequences resume they are treated as new prompts and recomputed.
    """
    # Explicit values, identical to what enum.auto() assigns (1, 2, ...).
    SWAP = 1
    RECOMPUTE = 2


@dataclass
class SchedulingBudget:
    """The available slots for scheduling.

    Tracks the tokens and sequences admitted so far against two limits,
    ``token_budget`` and ``max_num_seqs``. Updates are keyed by request id:
    adding for an id that is already counted is a no-op until that id is
    subtracted, and subtracting an id that is not counted is a no-op.

    TODO(sang): Right now, the budget is request_id-aware meaning it can ignore
    budget update from the same request_id. It is because in normal scheduling
    path, we update RUNNING num_seqs ahead of time, meaning it could be
    updated more than once when scheduling RUNNING requests. Since this won't
    happen if we only have chunked prefill scheduling, we can remove this
    feature from the API when chunked prefill is enabled by default.
    """
    # Upper bound on the total number of batched tokens.
    token_budget: int
    # Upper bound on the total number of sequences.
    max_num_seqs: int
    # Request ids whose tokens are currently counted in _num_batched_tokens.
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    # Request ids whose sequences are currently counted in _num_curr_seqs.
    _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int,
                     num_new_seqs: int) -> bool:
        """Return whether the extra tokens and seqs still fit in the budget.

        Both counts must be non-zero (enforced with ``assert``).
        """
        assert num_new_tokens != 0
        assert num_new_seqs != 0
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

    def remaining_token_budget(self) -> int:
        """Return ``token_budget`` minus the tokens batched so far."""
        return self.token_budget - self.num_batched_tokens

    def add_num_batched_tokens(self, req_id: str,
                               num_batched_tokens: int) -> None:
        """Count ``num_batched_tokens`` for ``req_id`` unless already counted."""
        if req_id in self._request_ids_num_batched_tokens:
            return

        self._request_ids_num_batched_tokens.add(req_id)
        self._num_batched_tokens += num_batched_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int) -> None:
        """Uncount ``num_batched_tokens`` for ``req_id`` if it was counted."""
        if req_id in self._request_ids_num_batched_tokens:
            self._request_ids_num_batched_tokens.remove(req_id)
            self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int) -> None:
        """Count ``num_curr_seqs`` for ``req_id`` unless already counted."""
        if req_id in self._request_ids_num_curr_seqs:
            return

        self._request_ids_num_curr_seqs.add(req_id)
        self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int) -> None:
        """Uncount ``num_curr_seqs`` for ``req_id`` if it was counted."""
        if req_id in self._request_ids_num_curr_seqs:
            self._request_ids_num_curr_seqs.remove(req_id)
            self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self) -> int:
        """Total tokens currently counted against ``token_budget``."""
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self) -> int:
        """Total sequences currently counted against ``max_num_seqs``."""
        return self._num_curr_seqs


@dataclass
class ScheduledSequenceGroup:
    """A sequence group chosen for scheduling and its token chunk size."""

    # The sequence group that was scheduled.
    seq_group: SequenceGroup
    # How many tokens to process for this group in the next iteration:
    # 1 for a decode; the prompt token count for a prefill, or fewer than
    # that when the prefill is chunked.
    token_chunk_size: int


@dataclass
class SchedulerOutputs:
    """The scheduling decision made from a scheduler."""
    # Scheduled sequence groups.
    scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
    # Number of prefill groups scheduled.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]]
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]]
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # The number of requests in the running queue.
    running_queue_size: int
    # NOTE(review): presumably the number of sequence groups preempted in
    # this step; confirm at the construction site.
    preempted: int

    def __post_init__(self) -> None:
        """Validate the swap lists and derive the LoRA/adapter counts.

        When any LoRA requests are present, ``scheduled_seq_groups`` is
        replaced by a list sorted by (lora_int_id, request_id).
        """
        # Swap in and swap out should never happen at the same time.
        assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)

        self.num_loras: int = len(self.lora_requests)
        if self.num_loras > 0:
            self._sort_by_lora_ids()

        self.num_prompt_adapters: int = len(self.prompt_adapter_requests)

    def is_empty(self) -> bool:
        """Return True if nothing is scheduled, swapped or copied."""
        # NOTE: We do not consider the ignored sequence groups.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

    def _sort_by_lora_ids(self) -> None:
        """Sort the scheduled groups by LoRA id, then request id."""
        # The request id tie-breaker makes the order deterministic; grouping
        # by LoRA id presumably keeps same-adapter requests adjacent.
        self.scheduled_seq_groups = sorted(
            self.scheduled_seq_groups,
            key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """Distinct non-None LoRA requests of the scheduled groups."""
        return {
            g.seq_group.lora_request
            for g in self.scheduled_seq_groups
            if g.seq_group.lora_request is not None
        }

    @property
    def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
        """Distinct non-None prompt adapter requests of the scheduled groups."""
        return {
            g.seq_group.prompt_adapter_request
            for g in self.scheduled_seq_groups
            if g.seq_group.prompt_adapter_request is not None
        }


@dataclass
class SchedulerRunningOutputs:
    """The requests that are scheduled from a running queue.

    Could contain prefill (prefill that's chunked) or decodes. If there's not
    enough memory, it can be preempted (for recompute) or swapped out.
    """
    # Selected sequences that are running and in a decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are running and in a prefill phase.
    # I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # Sequence groups preempted for recomputation.
    preempted: List[SequenceGroup]
    # Sequence groups preempted by swapping them out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    # Optimization for fast access: the bare SequenceGroups of
    # decode_seq_groups / prefill_seq_groups, in the same order.
    decode_seq_groups_list: List[SequenceGroup]
    prefill_seq_groups_list: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an instance with every list empty and no lookahead slots."""
        return SchedulerRunningOutputs(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            decode_seq_groups_list=[],
            prefill_seq_groups_list=[],
        )


@dataclass
class SchedulerSwappedInOutputs:
    """The requests that are scheduled from a swap queue.

    Could contain prefill (prefill that's chunked) or decodes.
    """
    # Selected sequences that are going to be swapped in and is in a
    # decoding phase. The scheduler wraps each group in a
    # ScheduledSequenceGroup, so the annotation says so.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are going to be swapped in and in a prefill
    # phase. I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # Infeasible sequence groups (they can never fit in the KV cache).
    infeasible_seq_groups: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance with every list empty and no lookahead slots."""
        return SchedulerSwappedInOutputs(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            blocks_to_swap_in=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            infeasible_seq_groups=[],
        )


@dataclass
class SchedulerPrefillOutputs:
    """The requests that are scheduled from a waiting queue.

    Could contain a fresh prefill requests or preempted requests that need
    to be recomputed from scratch.
    """
    # Selected sequences for prefill.
    seq_groups: List[SequenceGroup]
    # Ignored sequence groups.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an instance with every list empty and no lookahead slots."""
        return SchedulerPrefillOutputs(
            seq_groups=[],
            ignored_seq_groups=[],
            num_lookahead_slots=0,
        )


def seq_group_metadata_builder():
    """Build a blank SequenceGroupMetadata for a PyObjectCache to recycle."""
    # Fresh dicts on every call so cached objects never share state.
    blank_fields = dict(
        request_id="",
        is_prompt=False,
        seq_data={},
        sampling_params=None,
        block_tables={},
    )
    return SequenceGroupMetadata(**blank_fields)


def scheduler_running_outputs_builder():
    """Build an empty SchedulerRunningOutputs for a PyObjectCache."""
    # create_empty() fills every list field with a fresh [] and sets
    # num_lookahead_slots to 0.
    return SchedulerRunningOutputs.create_empty()


def scheduled_seq_group_builder():
    """Build a placeholder ScheduledSequenceGroup for a PyObjectCache."""
    # Both fields are overwritten by the caller after taking it from the
    # cache; positional order is (seq_group, token_chunk_size).
    return ScheduledSequenceGroup(None, 0)


class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
295
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        pipeline_parallel_size: int = 1,
    ) -> None:
        """Create the scheduler, its queues and its block space manager.

        Args:
            scheduler_config: Scheduling configuration (block manager
                version flags, preemption mode, prompt limits, ...).
            cache_config: KV-cache configuration. Its GPU/CPU block counts,
                when set, are floor-divided by ``pipeline_parallel_size``.
            lora_config: LoRA configuration, or None when LoRA is disabled.
            pipeline_parallel_size: Divisor applied to the block counts.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        # embedding_mode takes precedence over use_v2_block_manager.
        version = "v1"
        if self.scheduler_config.use_v2_block_manager:
            version = "v2"
        if self.scheduler_config.embedding_mode:
            version = "embedding"

        BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
            version)

        # Unset (falsy, e.g. None) block counts are passed through unchanged.
        num_gpu_blocks = cache_config.num_gpu_blocks
        if num_gpu_blocks:
            num_gpu_blocks //= pipeline_parallel_size

        num_cpu_blocks = cache_config.num_cpu_blocks
        if num_cpu_blocks:
            num_cpu_blocks //= pipeline_parallel_size

        # Create the block space manager.
        self.block_manager = BlockSpaceManagerImpl(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)

        # Sequence groups in the WAITING state.
        # Contain new prefill or preempted requests.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        # Contain decode and chunked-prefill requests.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        # Contain swapped-out decode or chunked-prefill requests.
        self.swapped: Deque[SequenceGroup] = deque()
        # Sequence groups finished requests ids since last step iteration.
        # It lets the model know that any state associated with these requests
        # can and must be released after the current step.
        # This is used to evict the finished requests from the Mamba cache.
        self._finished_requests_ids: List[str] = list()
        # Time at previous scheduling step
        self.prev_time = 0.0
        # Did we schedule a prompt at previous step?
        self.prev_prompt = False
        # Latency of the last prompt step
        self.last_prompt_latency = 0.0
        # preemption mode, RECOMPUTE or SWAP
        self.user_specified_preemption_mode = scheduler_config.preemption_mode

        # The following field is test-only. It is used to inject artificial
        # preemption.
        self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
        self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
                                       if self.enable_artificial_preemption
                                       else 0)
        self.num_cumulative_preemption: int = 0

        # Used to cache python objects
        self._seq_group_metadata_cache: PyObjectCache = PyObjectCache(
            seq_group_metadata_builder)
        self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
            scheduler_running_outputs_builder)
        self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
            scheduled_seq_group_builder)

    @property
    def lora_enabled(self) -> bool:
        return bool(self.lora_config)

377
378
379
380
381
    @property
    def num_decoding_tokens_per_seq(self) -> int:
        """The number of new tokens."""
        return 1

382
    def add_seq_group(self, seq_group: SequenceGroup) -> None:
383
        # Add sequence groups to the waiting queue.
384
        self.waiting.append(seq_group)
Woosuk Kwon's avatar
Woosuk Kwon committed
385

386
387
388
389
390
391
392
393
394
395
    def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
        # Add sequence groups to the running queue.
        # Only for testing purposes.
        self.running.append(seq_group)

    def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
        # Add sequence groups to the swapped queue.
        # Only for testing purposes.
        self.swapped.append(seq_group)

Antoni Baum's avatar
Antoni Baum committed
396
    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
397
398
399
400
401
402
403
404
405
406
407
408
        """Aborts a sequence group with the given ID.

        Check if the sequence group with the given ID
            is present in any of the state queue.
        If present, remove the sequence group from the state queue.
            Also, if any of the sequences in the sequence group is not finished,
                free the sequence with status `FINISHED_ABORTED`.
        Otherwise, do nothing.

        Args:
            request_id: The ID(s) of the sequence group to abort.
        """
Antoni Baum's avatar
Antoni Baum committed
409
410
411
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
412
        for state_queue in [self.waiting, self.running, self.swapped]:
ljss's avatar
ljss committed
413
            aborted_groups: List[SequenceGroup] = []
414
415
416
            for seq_group in state_queue:
                if not request_ids:
                    # Using 'break' here may add two extra iterations,
417
                    # but is acceptable to reduce complexity.
418
                    break
Antoni Baum's avatar
Antoni Baum committed
419
                if seq_group.request_id in request_ids:
420
421
                    # Appending aborted group into pending list.
                    aborted_groups.append(seq_group)
Antoni Baum's avatar
Antoni Baum committed
422
                    request_ids.remove(seq_group.request_id)
423
424
425
            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
                state_queue.remove(aborted_group)
426
                # Remove the aborted request from the Mamba cache.
427
                self._finished_requests_ids.append(aborted_group.request_id)
ljss's avatar
ljss committed
428
                for seq in aborted_group.get_seqs():
429
430
431
432
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)
433

434
435
436
437
438
439
440
441
442
443
444
445
446
                self._free_seq_group_cross_attn_blocks(aborted_group)

    def _free_seq_group_cross_attn_blocks(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """
        Free a sequence group from a cross-attention block table.
        Has no effect on decoder-only models.
        """
        if seq_group.is_encoder_decoder():
            self.block_manager.free_cross(seq_group)

447
    def has_unfinished_seqs(self) -> bool:
448
449
        return len(self.waiting) != 0 or len(self.running) != 0 or len(
            self.swapped) != 0
450

451
452
453
    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

Mor Zusman's avatar
Mor Zusman committed
454
455
456
457
458
459
    def get_and_reset_finished_requests_ids(self) -> List[str]:
        """Flushes the list of request ids of previously finished seq_groups."""
        finished_requests_ids = self._finished_requests_ids
        self._finished_requests_ids = list()
        return finished_requests_ids

460
    def _schedule_running(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerRunningOutputs:
        """Schedule sequence groups that are running.

        Running queue should include decode and chunked prefill requests.

        Groups are taken from the front of ``self.running``. When a group
        cannot get new slots, groups are preempted from the back of the queue
        until it can. If no other group is left, the group itself is
        preempted and the loop stops.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any decodes are preempted.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any decodes are preempted.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerRunningOutputs. The returned object and the
            ScheduledSequenceGroup objects inside it come from this
            scheduler's PyObjectCaches, which are reset before returning.
            NOTE(review): presumably the same objects are handed out again
            (and cleared, see the clear() calls below) on the next call;
            confirm the PyObjectCache.reset semantics before holding on to
            the result across scheduling steps.
        """
        ret: SchedulerRunningOutputs = \
            self._scheduler_running_outputs_cache.get_object()
        ret.blocks_to_swap_out.clear()
        ret.blocks_to_copy.clear()
        ret.decode_seq_groups.clear()
        ret.prefill_seq_groups.clear()
        ret.preempted.clear()
        ret.swapped_out.clear()

        ret.num_lookahead_slots = self._get_num_lookahead_slots(
            is_prefill=False)

        ret.decode_seq_groups_list.clear()
        ret.prefill_seq_groups_list.clear()

        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
        blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy

        decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
        prefill_seq_groups: List[
            ScheduledSequenceGroup] = ret.prefill_seq_groups
        preempted: List[SequenceGroup] = ret.preempted
        swapped_out: List[SequenceGroup] = ret.swapped_out

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.

        running_queue = self.running

        while running_queue:
            seq_group = running_queue[0]
            num_running_tokens = self._get_num_new_tokens(
                seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

            if num_running_tokens == 0:
                break

            # NOTE(review): scheduled groups are not pushed back onto
            # self.running here; presumably the caller rebuilds the queue
            # from the returned lists. Confirm at the call site.
            running_queue.popleft()
            while not self._can_append_slots(seq_group):
                budget.subtract_num_batched_tokens(seq_group.request_id,
                                                   num_running_tokens)
                num_running_seqs = seq_group.get_max_num_running_seqs()
                budget.subtract_num_seqs(seq_group.request_id,
                                         num_running_seqs)

                if (curr_loras is not None and seq_group.lora_int_id > 0
                        and seq_group.lora_int_id in curr_loras):
                    curr_loras.remove(seq_group.lora_int_id)

                if running_queue:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = running_queue.pop()
                    preempted_mode = self._preempt(victim_seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(victim_seq_group)
                    else:
                        swapped_out.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    preempted_mode = self._preempt(seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(seq_group)
                    else:
                        swapped_out.append(seq_group)
                    break
            else:
                # while/else: reached only when the loop above exits without
                # `break`, i.e. seq_group obtained its slots.
                self._append_slots(seq_group, blocks_to_copy)
                is_prefill = seq_group.is_prefill()

                scheduled_seq_group: ScheduledSequenceGroup = \
                    self._scheduled_seq_group_cache.get_object()
                scheduled_seq_group.seq_group = seq_group
                if is_prefill:
                    scheduled_seq_group.token_chunk_size = num_running_tokens
                    prefill_seq_groups.append(scheduled_seq_group)
                    ret.prefill_seq_groups_list.append(seq_group)
                else:
                    scheduled_seq_group.token_chunk_size = 1
                    decode_seq_groups.append(scheduled_seq_group)
                    ret.decode_seq_groups_list.append(seq_group)

                budget.add_num_batched_tokens(seq_group.request_id,
                                              num_running_tokens)
                # OPTIMIZATION: Note that get_max_num_running_seqs is
                # expensive. For the default scheduling case where
                # enable_chunking is False, num_seqs are updated before running
                # this method, so we don't have to update it again here.
                if enable_chunking:
                    num_running_seqs = seq_group.get_max_num_running_seqs()
                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)

                if curr_loras is not None and seq_group.lora_int_id > 0:
                    curr_loras.add(seq_group.lora_int_id)

        self._scheduler_running_outputs_cache.reset()
        self._scheduled_seq_group_cache.reset()

        return ret

    def _schedule_swapped(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerSwappedInOutputs:
        """Schedule sequence groups that are swapped out.

        It schedules swapped requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.

        Groups whose blocks can never be swapped in (``AllocStatus.NEVER``)
        have all their sequences marked ``FINISHED_IGNORED``, are removed from
        ``self.swapped`` and are returned as infeasible. Groups skipped only
        because no LoRA slot is free are put back at the front of
        ``self.swapped`` in their original relative order.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are swapped in.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are swapped in.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerSwappedInOutputs.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: List[Tuple[int, int]] = []
        blocks_to_copy: List[Tuple[int, int]] = []
        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []
        infeasible_seq_groups: List[SequenceGroup] = []

        swapped_queue = self.swapped

        # Groups skipped because of the LoRA limit. Filled with appendleft so
        # that the extendleft() after the loop restores their original order.
        leftover_swapped: Deque[SequenceGroup] = deque()
        while swapped_queue:
            seq_group = swapped_queue[0]

            # If the sequence group cannot be swapped in, stop.
            is_prefill = seq_group.is_prefill()
            alloc_status = self.block_manager.can_swap_in(
                seq_group, self._get_num_lookahead_slots(is_prefill))
            if alloc_status == AllocStatus.LATER:
                break
            elif alloc_status == AllocStatus.NEVER:
                logger.warning(
                    "Failing the request %s because there's not enough kv "
                    "cache blocks to run the entire sequence.",
                    seq_group.request_id)
                for seq in seq_group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible_seq_groups.append(seq_group)
                swapped_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_swapped.appendleft(seq_group)
                    swapped_queue.popleft()
                    continue

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.SWAPPED,
                                                      enable_chunking, budget)

            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
                break

            if lora_int_id > 0 and curr_loras is not None:
                curr_loras.add(lora_int_id)
            swapped_queue.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slots(seq_group, blocks_to_copy)
            is_prefill = seq_group.is_prefill()
            if is_prefill:
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(seq_group,
                                           token_chunk_size=num_new_tokens))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        swapped_queue.extendleft(leftover_swapped)

        return SchedulerSwappedInOutputs(
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False),
            infeasible_seq_groups=infeasible_seq_groups,
        )

    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
        if self.scheduler_config.chunked_prefill_enabled:
            prompt_limit = self.scheduler_config.max_model_len
        else:
            prompt_limit = min(self.scheduler_config.max_model_len,
                               self.scheduler_config.max_num_batched_tokens)

        # Model is fine tuned with long context. Return the fine tuned max_len.
        if (seq_group.lora_request
                and seq_group.lora_request.long_lora_max_len):
            assert prompt_limit <= seq_group.lora_request.long_lora_max_len
            return seq_group.lora_request.long_lora_max_len
        else:
            return prompt_limit

708
709
710
711
    def _schedule_prefills(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerPrefillOutputs:
        """Schedule sequence groups that are in prefill stage.

        Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
        as a new prefill (that starts from beginning -> most recently generated
        tokens).

        It schedules waiting requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are scheduled.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are scheduled.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerPrefillOutputs holding the scheduled groups, the groups
            ignored because their prompt can never be served, and the number
            of lookahead slots for prefill.
        """
        ignored_seq_groups: List[SequenceGroup] = []
        seq_groups: List[ScheduledSequenceGroup] = []

        waiting_queue = self.waiting

        # Groups skipped only because no LoRA slot is free this step; they
        # are put back at the front of the waiting queue afterwards.
        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
        while self._passed_delay(time.time()) and waiting_queue:
            seq_group = waiting_queue[0]

            waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.WAITING,
                                                      enable_chunking, budget)
            # Full length of the sequence to prefill. With chunking,
            # num_new_tokens may be just a budget-limited chunk of it.
            num_prompt_tokens = waiting_seqs[0].get_len()
            if not enable_chunking:
                assert num_new_tokens == num_prompt_tokens

            # Validate the whole prompt, not only this step's chunk: under
            # chunking num_new_tokens is capped by the remaining token budget
            # and would let an over-long prompt slip past the limit.
            prompt_limit = self._get_prompt_limit(seq_group)
            if num_prompt_tokens > prompt_limit:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds limit of %d", num_prompt_tokens,
                    prompt_limit)
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            # If the sequence group cannot be allocated, stop.
            can_allocate = self.block_manager.can_allocate(seq_group)
            if can_allocate == AllocStatus.LATER:
                break
            elif can_allocate == AllocStatus.NEVER:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds the capacity of block_manager",
                    num_prompt_tokens)
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and lora_int_id not in curr_loras
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_waiting_sequences.appendleft(seq_group)
                    waiting_queue.popleft()
                    continue

            num_new_seqs = seq_group.get_max_num_running_seqs()
            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
                break

            # Can schedule this request.
            if curr_loras is not None and lora_int_id > 0:
                curr_loras.add(lora_int_id)
            waiting_queue.popleft()
            self._allocate_and_set_running(seq_group)
            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=num_new_tokens))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Queue requests that couldn't be scheduled. Because the leftovers
        # were collected with appendleft, extendleft restores their
        # original relative order at the front of the queue.
        waiting_queue.extendleft(leftover_waiting_sequences)
        if len(seq_groups) > 0:
            self.prev_prompt = True

        return SchedulerPrefillOutputs(
            seq_groups=seq_groups,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))

824
825
    def _schedule_default(self) -> SchedulerOutputs:
        """Schedule queued requests with the throughput-oriented policy.

        When nothing is swapped out, as many waiting prefills as fit are
        batched first. Only if no prefill was scheduled are the running
        decodes scheduled. Swapped-out requests are brought back only when
        no running request had to be preempted or swapped out. Under GPU
        memory pressure, decode requests can be swapped or preempted.
        """
        config = self.scheduler_config
        budget = SchedulingBudget(
            token_budget=config.max_num_batched_tokens,
            max_num_seqs=config.max_num_seqs,
        )
        # Count the already-running sequences up front, so prefill
        # scheduling cannot push the batch past max_num_seqs.
        for running_group in self.running:
            budget.add_num_seqs(running_group.request_id,
                                running_group.get_max_num_running_seqs())

        curr_loras: Optional[Set[int]]
        if self.lora_enabled:
            curr_loras = {
                running_group.lora_int_id
                for running_group in self.running
                if running_group.lora_int_id > 0
            }
        else:
            curr_loras = None

        prefills = SchedulerPrefillOutputs.create_empty()
        running_scheduled = SchedulerRunningOutputs.create_empty()
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Swapped requests take priority: new prefills wait until the
        # swapped queue is empty.
        if not self.swapped:
            prefills = self._schedule_prefills(budget,
                                               curr_loras,
                                               enable_chunking=False)
        num_prefill_groups = len(prefills.seq_groups)

        # A step that schedules prefills schedules no decodes. Without
        # chunking, self.running holds only decode requests.
        if num_prefill_groups == 0:
            running_scheduled = self._schedule_running(budget,
                                                       curr_loras,
                                                       enable_chunking=False)
            # Any preemption means there is no room to swap requests in.
            num_evicted = (len(running_scheduled.preempted) +
                           len(running_scheduled.swapped_out))
            if num_evicted == 0:
                swapped_in = self._schedule_swapped(budget, curr_loras)

        assert budget.num_batched_tokens <= config.max_num_batched_tokens
        assert budget.num_curr_seqs <= config.max_num_seqs

        # Recompute-preempted requests go back to the front of the queue.
        self.waiting.extendleft(running_scheduled.preempted)
        # New running queue: prefills, then decodes, then swapped-in decodes.
        self.running.extend(s.seq_group for s in prefills.seq_groups)
        self.running.extend(running_scheduled.decode_seq_groups_list)
        self.running.extend(s.seq_group for s in swapped_in.decode_seq_groups)
        # Requests swapped out in this step join the swapped queue.
        self.swapped.extend(running_scheduled.swapped_out)
        preempted = (len(running_scheduled.preempted) +
                     len(running_scheduled.swapped_out))

        # This policy never chunks prefills, so neither the running queue
        # nor the swapped queue can produce prefill groups.
        assert len(running_scheduled.prefill_seq_groups) == 0
        assert len(swapped_in.prefill_seq_groups) == 0

        # Merge lists in place, prefills first: the first source list is
        # extended rather than copied.
        if num_prefill_groups > 0:
            scheduled_seq_groups = prefills.seq_groups
            scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
        else:
            scheduled_seq_groups = running_scheduled.decode_seq_groups
        scheduled_seq_groups.extend(swapped_in.decode_seq_groups)

        blocks_to_copy = running_scheduled.blocks_to_copy
        blocks_to_copy.extend(swapped_in.blocks_to_copy)

        ignored_seq_groups = prefills.ignored_seq_groups
        ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)

        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_seq_groups,
            num_prefill_groups=num_prefill_groups,
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=preempted,
        )

924
    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
        """Schedule queued requests with the chunked-prefill policy.

        Prefill requests may be chunked and batched together with decodes.
        The budget is consumed in this order:
        1. running requests (decodes and unfinished chunked prefills), FCFS;
        2. swapped-out requests, unless step 1 had to preempt anything;
        3. new prefill requests from the waiting queue.

        Mixing prefills and decodes in one batch keeps GPU utilization high.
        Decodes are not blocked behind prefills, which improves inter-token
        latency.
        """
        config = self.scheduler_config
        budget = SchedulingBudget(
            token_budget=config.max_num_batched_tokens,
            max_num_seqs=config.max_num_seqs,
        )
        curr_loras: Set[int] = set()

        prefills = SchedulerPrefillOutputs.create_empty()
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Running requests always go first, by FCFS.
        running_scheduled = self._schedule_running(budget,
                                                   curr_loras,
                                                   enable_chunking=True)

        # Preemption means there is no space left for swapping in.
        num_evicted = (len(running_scheduled.preempted) +
                       len(running_scheduled.swapped_out))
        if num_evicted == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)

        # New prefills take whatever budget remains.
        prefills = self._schedule_prefills(budget,
                                           curr_loras,
                                           enable_chunking=True)

        assert budget.num_batched_tokens <= config.max_num_batched_tokens
        assert budget.num_curr_seqs <= config.max_num_seqs

        # Recompute-preempted requests go back to the front of the queue.
        self.waiting.extendleft(running_scheduled.preempted)
        # Running-queue order: new prefills, running decodes, running chunked
        # prefills, swapped-in decodes, swapped-in prefills.
        for group_list in (prefills.seq_groups,
                           running_scheduled.decode_seq_groups,
                           running_scheduled.prefill_seq_groups,
                           swapped_in.decode_seq_groups,
                           swapped_in.prefill_seq_groups):
            self.running.extend(s.seq_group for s in group_list)
        # Requests swapped out in this step join the swapped queue.
        self.swapped.extend(running_scheduled.swapped_out)

        # Every prefill group precedes every decode group in the batch.
        scheduled_prefills = (prefills.seq_groups +
                              running_scheduled.prefill_seq_groups +
                              swapped_in.prefill_seq_groups)
        scheduled_decodes = (running_scheduled.decode_seq_groups +
                             swapped_in.decode_seq_groups)

        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_prefills + scheduled_decodes,
            num_prefill_groups=len(scheduled_prefills),
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=(running_scheduled.blocks_to_copy +
                            swapped_in.blocks_to_copy),
            ignored_seq_groups=(prefills.ignored_seq_groups +
                                swapped_in.infeasible_seq_groups),
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=(len(running_scheduled.preempted) +
                       len(running_scheduled.swapped_out)),
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
1002

1003
1004
1005
1006
1007
1008
1009
    def _schedule(self) -> SchedulerOutputs:
        """Schedule queued requests."""
        if self.scheduler_config.chunked_prefill_enabled:
            return self._schedule_chunked_prefill()
        else:
            return self._schedule_default()

1010
1011
1012
1013
    def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
        """Determine whether or not we have enough space in the KV cache to
        continue generation of the sequence group.
        """
1014
1015
1016
1017
1018
1019
1020
        # It is True only for testing case to trigger artificial preemption.
        if (self.enable_artificial_preemption
                and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
                and self.artificial_preempt_cnt > 0):
            self.artificial_preempt_cnt -= 1
            return False

1021
1022
1023
1024
1025
1026
1027
1028
        # Appending slots only occurs in decoding.
        is_prefill = False

        return self.block_manager.can_append_slots(
            seq_group=seq_group,
            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
1029
    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Run one scheduling step and build the per-group model inputs.

        Calls ``_schedule()``, which mutates ``self.running``,
        ``self.swapped`` and ``self.waiting``. It then does the following:
        builds one SequenceGroupMetadata per scheduled group, marks the
        scheduled groups' blocks as computed, and adds the elapsed scheduler
        time to the metrics of every running group.

        Returns:
            A tuple ``(seq_group_metadata_list, scheduler_outputs)``. The
            metadata list follows the order of
            ``scheduler_outputs.scheduled_seq_groups``.
        """
        # Start timing before _schedule(): it does the bulk of the
        # scheduling work, so starting afterwards would under-report the
        # scheduler's share of end-to-end latency.
        scheduler_start_time = time.perf_counter()
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()
        now = time.time()

        if not self.cache_config.enable_prefix_caching:
            common_computed_block_nums = []

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            token_chunk_size = scheduled_seq_group.token_chunk_size
            seq_group.maybe_set_first_scheduled_time(now)

            # Reuse a cached metadata object: its dicts are cleared here and
            # the object is re-initialized below.
            seq_group_metadata = self._seq_group_metadata_cache.get_object()
            seq_group_metadata.seq_data.clear()
            seq_group_metadata.block_tables.clear()

            # seq_id -> SequenceData
            seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data
            # seq_id -> physical block numbers
            block_tables: Dict[int,
                               List[int]] = seq_group_metadata.block_tables

            if seq_group.is_encoder_decoder():
                # Encoder associated with SequenceGroup
                encoder_seq_data = seq_group.get_encoder_seq().data
                # Block table for cross-attention
                # Also managed at SequenceGroup level
                cross_block_table = self.block_manager.get_cross_block_table(
                    seq_group)
            else:
                encoder_seq_data = None
                cross_block_table = None

            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                self.block_manager.access_all_blocks_in_seq(seq, now)

            if self.cache_config.enable_prefix_caching:
                common_computed_block_nums = (
                    self.block_manager.get_common_computed_block_ids(
                        seq_group.get_seqs(status=SequenceStatus.RUNNING)))

            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
            is_prompt = seq_group.is_prefill()
            do_sample = True
            if is_prompt:
                seqs = seq_group.get_seqs()
                # Prefill has only 1 sequence.
                assert len(seqs) == 1
                # In the next iteration, all prompt tokens are not computed.
                # It means the prefill is chunked, and we don't need sampling.
                # NOTE: We use get_len instead of get_prompt_len because when
                # a sequence is preempted, prefill includes previous generated
                # output tokens.
                if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
                        seqs[0].data.get_len()):
                    do_sample = False

            seq_group_metadata.__init__(
                request_id=seq_group.request_id,
                is_prompt=is_prompt,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                do_sample=do_sample,
                pooling_params=seq_group.pooling_params,
                token_chunk_size=token_chunk_size,
                lora_request=seq_group.lora_request,
                computed_block_nums=common_computed_block_nums,
                encoder_seq_data=encoder_seq_data,
                cross_block_table=cross_block_table,
                # `multi_modal_data` will only be present for the 1st comm
                # between engine and worker.
                # the subsequent comms can still use delta, but
                # `multi_modal_data` will be None.
                multi_modal_data=seq_group.multi_modal_data
                if scheduler_outputs.num_prefill_groups > 0 else None,
                prompt_adapter_request=seq_group.prompt_adapter_request,
            )
            seq_group_metadata_list.append(seq_group_metadata)

        # Now that the batch has been created, we can assume all blocks in the
        # batch will have been computed before the next scheduling invocation.
        # This is because the engine assumes that a failure in model execution
        # will crash the vLLM instance / will not retry.
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            self.block_manager.mark_blocks_as_computed(
                scheduled_seq_group.seq_group)

        self._seq_group_metadata_cache.reset()

        scheduler_time = time.perf_counter() - scheduler_start_time
        # Add this to scheduler time to all the sequences that are currently
        # running. This will help estimate if the scheduler is a significant
        # component in the e2e latency.
        for seq_group in self.running:
            if seq_group is not None and seq_group.metrics is not None:
                if seq_group.metrics.scheduler_time is not None:
                    seq_group.metrics.scheduler_time += scheduler_time
                else:
                    seq_group.metrics.scheduler_time = scheduler_time

        return seq_group_metadata_list, scheduler_outputs
1143

1144
1145
    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Fork ``parent_seq`` into ``child_seq`` through the block manager."""
        manager = self.block_manager
        manager.fork(parent_seq, child_seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
1146

1147
    def free_seq(self, seq: Sequence) -> None:
        """Release the blocks held by ``seq`` via the block manager."""
        manager = self.block_manager
        manager.free(seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
1150

1151
    def free_finished_seq_groups(self) -> None:
        """Drop finished sequence groups from the running queue.

        Each finished group has its cross-attention blocks freed, and its
        request id is appended to ``_finished_requests_ids``. Unfinished
        groups keep their relative order in a new deque that replaces
        ``self.running``.
        """
        still_running: Deque[SequenceGroup] = deque()
        for group in self.running:
            if not group.is_finished():
                still_running.append(group)
                continue
            # Free cross-attention block table, if it exists.
            self._free_seq_group_cross_attn_blocks(group)
            # Recorded so the Mamba cache can be updated in the next step.
            self._finished_requests_ids.append(group.request_id)
        self.running = still_running
Woosuk Kwon's avatar
Woosuk Kwon committed
1164

1165
    def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for ``seq_group`` and mark its WAITING seqs RUNNING."""
        self.block_manager.allocate(seq_group)
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        for waiting_seq in waiting_seqs:
            waiting_seq.status = SequenceStatus.RUNNING

1170
    def _append_slots(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        """Append decode slots to every RUNNING sequence of ``seq_group``.

        Args:
            seq_group (SequenceGroup): Group whose running sequences receive
                new slots (plus the decode lookahead slots).
            blocks_to_copy (List[Tuple[int, int]]): Extended in place with the
                ``(source block, destination block)`` pairs returned by the
                block manager for the appended slots.
        """
        lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        for running_seq in running_seqs:
            copy_pairs = self.block_manager.append_slots(running_seq,
                                                         lookahead_slots)
            if copy_pairs:
                blocks_to_copy.extend(copy_pairs)
1192
1193
1194
1195

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> PreemptionMode:
        """Preempt ``seq_group`` by recomputation or by swapping.

        Args:
            seq_group: The sequence group to preempt.
            blocks_to_swap_out: Extended in place with the swap-out block
                mapping when swapping is used.
            preemption_mode: Mode to use. When None (the default), the mode
                is derived from ``user_specified_preemption_mode`` and the
                number of running sequences in the group (see below).

        Returns:
            The preemption mode that was actually applied.

        Raises:
            AssertionError: If the resolved mode is not a known mode.
            RuntimeError: Propagated from ``_swap_out`` when there is not
                enough CPU swap space.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        # An explicitly passed mode is honored instead of being overwritten.
        if preemption_mode is None:
            if self.user_specified_preemption_mode is None:
                if seq_group.get_max_num_running_seqs() == 1:
                    preemption_mode = PreemptionMode.RECOMPUTE
                else:
                    preemption_mode = PreemptionMode.SWAP
            elif self.user_specified_preemption_mode == "swap":
                preemption_mode = PreemptionMode.SWAP
            else:
                preemption_mode = PreemptionMode.RECOMPUTE

        # Warn on the first preemption and on every 50th one after that.
        if self.num_cumulative_preemption % 50 == 0:
            logger.warning(
                "Sequence group %s is preempted by %s mode because there is "
                "not enough KV cache space. This can affect the end-to-end "
                "performance. Increase gpu_memory_utilization or "
                "tensor_parallel_size to provide more KV cache memory. "
                "total_num_cumulative_preemption=%d", seq_group.request_id,
                preemption_mode, self.num_cumulative_preemption + 1)
        self.num_cumulative_preemption += 1

        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")
        return preemption_mode
1238
1239
1240
1241
1242
1243
1244
1245
1246

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Preempt ``seq_group`` so that it is recomputed from scratch later.

        The group's single RUNNING sequence is set back to WAITING, its
        blocks are freed, and its state is reset for recomputation.
        """
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Recomputation only supports single-sequence groups.
        assert len(running_seqs) == 1
        for running_seq in running_seqs:
            running_seq.status = SequenceStatus.WAITING
            self.free_seq(running_seq)
            running_seq.reset_state_for_recompute()
1249
1250
1251
1252

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
1253
        blocks_to_swap_out: List[Tuple[int, int]],
1254
1255
1256
1257
1258
1259
    ) -> None:
        self._swap_out(seq_group, blocks_to_swap_out)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: List[Tuple[int, int]],
    ) -> None:
        """Swap ``seq_group`` back in and mark its SWAPPED seqs RUNNING.

        ``blocks_to_swap_in`` is extended in place with the block mapping
        returned by the block manager.
        """
        swap_mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.extend(swap_mapping)
        swapped_seqs = seq_group.get_seqs(status=SequenceStatus.SWAPPED)
        for swapped_seq in swapped_seqs:
            swapped_seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
1270
        blocks_to_swap_out: List[Tuple[int, int]],
1271
    ) -> None:
1272
1273
1274
1275
1276
1277
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
1278
        mapping = self.block_manager.swap_out(seq_group)
1279
        blocks_to_swap_out.extend(mapping)
1280
1281
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED
1282

1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
    def _passed_delay(self, now: float) -> bool:
        if self.prev_prompt:
            self.last_prompt_latency = now - self.prev_time
        self.prev_time, self.prev_prompt = now, False
        # Delay scheduling prompts to let waiting queue fill up
        if self.scheduler_config.delay_factor > 0 and self.waiting:
            earliest_arrival_time = min(
                [e.metrics.arrival_time for e in self.waiting])
            passed_delay = (
                (now - earliest_arrival_time) >
                (self.scheduler_config.delay_factor * self.last_prompt_latency)
                or not self.running)
        else:
            passed_delay = True
        return passed_delay
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310

    def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
        """The number of slots to allocate per sequence per step, beyond known
        token ids. Speculative decoding uses these slots to store KV activations
        of tokens which may or may not be accepted.

        Speculative decoding does not yet support prefill, so we do not perform
        lookahead allocation for prefill.
        """
        if is_prefill:
            return 0

        return self.scheduler_config.num_lookahead_slots
1311
1312
1313

    def _get_num_new_tokens(self, seq_group: SequenceGroup,
                            status: SequenceStatus, enable_chunking: bool,
1314
                            budget: SchedulingBudget) -> int:
1315
1316
1317
1318
1319
1320
1321
        """Get the next new tokens to compute for a given sequence group
            that's in a given `status`.

        The API could chunk the number of tokens to compute based on `budget`
        if `enable_chunking` is True. If a sequence group has multiple
        sequences (e.g., running beam search), it means it is in decoding
        phase, so chunking doesn't happen.
1322
1323

        Returns 0 if the new token cannot be computed due to token budget.
1324
1325
1326
1327
1328
        """
        num_new_tokens = 0
        seqs = seq_group.get_seqs(status=status)
        for seq in seqs:
            num_new_tokens += seq.get_num_new_tokens()
1329
        assert num_new_tokens > 0
1330
1331
1332
1333
1334
1335
1336
        # Chunk if a running request cannot fit in.
        # If number of seq > 1, it means it is doing beam search in a
        # decode phase. Do not chunk in that case.
        if enable_chunking and len(seqs) == 1:
            num_new_tokens = min(num_new_tokens,
                                 budget.remaining_token_budget())
        return num_new_tokens