scheduler.py 91.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
6
import os
import random
7
import time
8
from collections import deque
9
from dataclasses import dataclass, field
10
11
12
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
15
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
19
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
20
21
22
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupMetadataDelta, SequenceStage,
                           SequenceStatus)
23
from vllm.utils import Device, PyObjectCache
Woosuk Kwon's avatar
Woosuk Kwon committed
24

Woosuk Kwon's avatar
Woosuk Kwon committed
25
logger = init_logger(__name__)
26

27
28
29
30
31
32
33
# Test-only. If configured, decode is preempted with
# ARTIFICIAL_PREEMPTION_PROB% probability.
#
# The switch is read from the VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT environment
# variable. An unset or empty variable, or one of the explicit "off" spellings
# below (case-insensitive), disables it; every other value enables it. A plain
# bool() on the raw string would treat "0" and "false" as enabled.
ENABLE_ARTIFICIAL_PREEMPT = (os.getenv(
    "VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT",
    "").strip().lower() not in ("", "0", "false", "no", "off"))
ARTIFICIAL_PREEMPTION_PROB = 0.5
ARTIFICIAL_PREEMPTION_MAX_CNT = 500

Woosuk Kwon's avatar
Woosuk Kwon committed
34

35
36
37
38
39
40
41
42
43
class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """

    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


49
50
@dataclass
class SchedulingBudget:
    """The available slots for scheduling.

    Tracks how many tokens and sequences have been added to the current batch
    and compares them against ``token_budget`` and ``max_num_seqs``.

    TODO(sang): Right now, the budget is request_id-aware meaning it can ignore
    budget update from the same request_id. It is because in normal scheduling
    path, we update RUNNING num_seqs ahead of time, meaning it could be
    updated more than once when scheduling RUNNING requests. Since this won't
    happen if we only have chunked prefill scheduling, we can remove this
    feature from the API when chunked prefill is enabled by default.
    """

    # Upper bound on the number of batched (non-cached) tokens.
    token_budget: int
    # Upper bound on the number of sequences in the batch.
    max_num_seqs: int
    # Request ids whose tokens / sequences are already counted; used to make
    # repeated add_* calls for the same request a no-op.
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    # Number of cached tokens in the batch.
    _num_cached_tokens: int = 0
    # Number of actual non-cached tokens in the batch.
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int) -> bool:
        """Return True if the extra tokens and sequences fit in the budget.

        Args:
            num_new_tokens: Tokens to add; must be >= 0.
            num_new_seqs: Sequences to add; must not be 0.
        """
        # We allow num_new_tokens to be 0 when the entire sequence has
        # been cached.
        assert num_new_tokens >= 0
        assert num_new_seqs != 0
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

    def remaining_token_budget(self) -> int:
        """Return the number of tokens that can still be batched."""
        return self.token_budget - self.num_batched_tokens

    def add_num_batched_tokens(self,
                               req_id: str,
                               num_batched_tokens: int,
                               num_cached_tokens: int = 0) -> None:
        """Count the tokens of ``req_id`` once; later calls are ignored.

        Args:
            req_id: Request whose tokens are being accounted for.
            num_batched_tokens: Non-cached tokens to add; must be >= 0.
            num_cached_tokens: Cached tokens to add; must be >= 0.
        """
        if req_id in self._request_ids_num_batched_tokens:
            return
        assert num_cached_tokens >= 0
        assert num_batched_tokens >= 0

        self._request_ids_num_batched_tokens.add(req_id)
        self._num_batched_tokens += num_batched_tokens
        self._num_cached_tokens += num_cached_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int) -> None:
        """Remove ``req_id``'s batched tokens if it was counted before.

        Only the batched-token counter is reduced; the cached-token counter
        is left unchanged.
        """
        if req_id in self._request_ids_num_batched_tokens:
            self._request_ids_num_batched_tokens.remove(req_id)
            self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int) -> None:
        """Count the sequences of ``req_id`` once; later calls are ignored."""
        if req_id in self._request_ids_num_curr_seqs:
            return

        self._request_ids_num_curr_seqs.add(req_id)
        self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int) -> None:
        """Remove ``req_id``'s sequences if they were counted before."""
        if req_id in self._request_ids_num_curr_seqs:
            self._request_ids_num_curr_seqs.remove(req_id)
            self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self) -> int:
        """Non-cached tokens currently counted in the batch."""
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self) -> int:
        """Sequences currently counted in the batch."""
        return self._num_curr_seqs

    @property
    def num_cached_tokens(self) -> int:
        """Cached tokens currently counted in the batch."""
        return self._num_cached_tokens

125

126
127
128
129
130
131
132
133
134
135
@dataclass
class ScheduledSequenceGroup:
    """A sequence group paired with the token count chosen for its next step."""

    # The sequence group that has been selected for scheduling.
    seq_group: SequenceGroup
    # How many tokens to process in the next iteration: 1 when decoding; the
    # full prompt length when prefilling, or fewer if the prefill is chunked.
    token_chunk_size: int


136
@dataclass
class SchedulerOutputs:
    """The scheduling decision made from a scheduler."""

    # Scheduled sequence groups.
    scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
    # Number of prefill groups scheduled.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]]
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]]
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # The number of requests in the running queue
    running_queue_size: int
    # Preemption count reported with this decision.
    preempted: int

    def __post_init__(self):
        """Validate the swap lists and derive the LoRA / adapter counts."""
        # Swap in and swap out should never happen at the same time.
        assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)

        self.num_loras: int = len(self.lora_requests)
        if self.num_loras > 0:
            self._sort_by_lora_ids()

        self.num_prompt_adapters: int = len(self.prompt_adapter_requests)

    def is_empty(self) -> bool:
        """Return True when nothing is scheduled and no block moves exist."""
        # NOTE: We do not consider the ignored sequence groups.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

    def _sort_by_lora_ids(self):
        """Re-order ``scheduled_seq_groups`` by (lora_int_id, request_id).

        When the batch mixes prefills and decodes, prefills are kept in front.
        """
        assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)

        def key_fn(group: ScheduledSequenceGroup):
            key = (group.seq_group.lora_int_id, group.seq_group.request_id)
            if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
                # Sort sequence groups so that all prefills come before all
                # decodes as required by chunked prefill.
                return (not group.seq_group.is_prefill(), *key)
            return key

        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
                                           key=key_fn)

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """Distinct non-None LoRA requests among the scheduled groups."""
        return {
            g.seq_group.lora_request
            for g in self.scheduled_seq_groups
            if g.seq_group.lora_request is not None
        }

    @property
    def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
        """Distinct non-None prompt-adapter requests among scheduled groups."""
        return {
            g.seq_group.prompt_adapter_request
            for g in self.scheduled_seq_groups
            if g.seq_group.prompt_adapter_request is not None
        }

205

206
@dataclass
class SchedulerRunningOutputs:
    """The requests that are scheduled from a running queue.

    Could contain prefill (prefill that's chunked) or decodes. If there's not
    enough memory, it can be preempted (for recompute) or swapped out.
    """

    # Selected sequences that are running and in a decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are running and in a prefill phase.
    # I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The preempted sequences.
    preempted: List[SequenceGroup]
    # Sequences that are swapped out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    # Optimization for fast-access to seq_group lists
    decode_seq_groups_list: List[SequenceGroup]
    prefill_seq_groups_list: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an instance whose lists are fresh and empty."""
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            decode_seq_groups_list=[],
            prefill_seq_groups_list=[],
        )


@dataclass
class SchedulerSwappedInOutputs:
    """The requests that are scheduled from a swap queue.

    Could contain prefill (prefill that's chunked) or decodes.
    """

    # Selected sequences that are going to be swapped in and is in a
    # decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are going to be swapped in and in a prefill
    # phase. I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # Infeasible sequence groups.
    infeasible_seq_groups: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance whose lists are fresh and empty."""
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            blocks_to_swap_in=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            infeasible_seq_groups=[],
        )


@dataclass
class SchedulerPrefillOutputs:
    """The requests that are scheduled from a waiting queue.

    Could contain a fresh prefill requests or preempted requests that need
    to be recomputed from scratch.
    """

    # Selected sequences for prefill.
    seq_groups: List[ScheduledSequenceGroup]
    # Ignored sequence groups.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an instance whose lists are fresh and empty."""
        return cls(
            seq_groups=[],
            ignored_seq_groups=[],
            num_lookahead_slots=0,
        )


306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
def seq_group_metadata_builder():
    """Return a blank, non-prompt SequenceGroupMetadata with empty containers."""
    blank_fields = dict(
        request_id="",
        is_prompt=False,
        seq_data={},
        sampling_params=None,
        block_tables={},
    )
    return SequenceGroupMetadata(**blank_fields)


def scheduler_running_outputs_builder():
    """Return a SchedulerRunningOutputs with empty lists and zero lookahead."""
    # create_empty() supplies the same field values: a fresh empty list for
    # every list field and num_lookahead_slots=0.
    return SchedulerRunningOutputs.create_empty()


def scheduled_seq_group_builder():
    """Return a placeholder ScheduledSequenceGroup with token_chunk_size=0.

    The wrapped SequenceGroup is allocated with ``__new__`` only, so its
    ``__init__`` is not run.
    """
    return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup),
                                  token_chunk_size=0)
330
331


332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
@dataclass
class PartialPrefillMetadata:
    """Per-iteration bookkeeping for concurrently running partial prefills.

    With chunked prefill enabled, several sequences may be partially
    prefilled in the same scheduler iteration. Keeping more than one partial
    prefill in flight helps TTFT and avoids starving decodes behind a single
    very large prompt. The number of *long* prefills is capped so that
    shorter requests can overtake them and reach the decode phase sooner.
    """

    # A minimum bound on the total number of prefills to be scheduled during
    # this iteration
    schedulable_prefills: int

    # The number of long prefill requests currently running
    long_prefills: int

    scheduler_config: SchedulerConfig

    def can_schedule(self, seq_group: SequenceGroup) -> bool:
        """Return False only for a long request that exceeds the long cap.

        A request is refused when it is long, the cap on long partial
        prefills is already reached, and more than one partial prefill is
        allowed at a time. Every other request is accepted.
        """
        cfg = self.scheduler_config
        pending_tokens = seq_group.first_seq.get_num_new_tokens()
        if pending_tokens <= cfg.long_prefill_token_threshold:
            # Short requests are always admissible.
            return True
        if self.long_prefills < cfg.max_long_partial_prefills:
            # There is still room for another long request.
            return True
        # The cap only applies when concurrent partial prefills are enabled.
        return cfg.max_num_partial_prefills <= 1

    def maybe_increment_partial_prefills(self,
                                         seq_group: SequenceGroup) -> None:
        """Count ``seq_group`` as a long prefill if it is above the threshold."""
        pending_tokens = seq_group.first_seq.get_num_new_tokens()
        if pending_tokens > self.scheduler_config.long_prefill_token_threshold:
            self.long_prefills += 1

    @classmethod
    def from_queues(
        cls,
        running: Deque[SequenceGroup],
        waiting: Deque[SequenceGroup],
        scheduler_config: SchedulerConfig,
    ) -> "PartialPrefillMetadata":
        """Build the metadata from the scheduler's running and waiting queues.

        Prefills already in the running queue are counted first. The waiting
        queue is then scanned for further prefills that could be scheduled in
        this iteration.
        """

        def exceeds_threshold(group) -> bool:
            return (group.first_seq.get_num_new_tokens()
                    > scheduler_config.long_prefill_token_threshold)

        total_prefills = 0
        running_long = 0
        queued_long = 0

        for group in running:
            if group.first_seq.data.stage != SequenceStage.PREFILL:
                continue
            total_prefills += 1
            if exceeds_threshold(group):
                running_long += 1

        for group in waiting:
            # Once max_num_partial_prefills candidates are known there is no
            # point in scanning the rest of the queue.
            if total_prefills >= scheduler_config.max_num_partial_prefills:
                break

            # Long waiting requests that would not be scheduled anyway are
            # left out of the count.
            if exceeds_threshold(group):
                long_taken = running_long + queued_long
                if long_taken >= scheduler_config.max_long_partial_prefills:
                    continue
                queued_long += 1
            total_prefills += 1

        # NB: only the running long prefills are stored. The waiting ones are
        # tracked separately because this metadata is later used to count how
        # many have actually been scheduled.
        capped_prefills = min(total_prefills,
                              scheduler_config.max_num_partial_prefills)
        return PartialPrefillMetadata(
            schedulable_prefills=capped_prefills,
            long_prefills=running_long,
            scheduler_config=scheduler_config,
        )


Woosuk Kwon's avatar
Woosuk Kwon committed
426
427
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
428
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        pipeline_parallel_size: int = 1,
        output_proc_callback: Optional[Callable] = None,
    ) -> None:
        """Set up the block manager, state queues and object caches.

        Args:
            scheduler_config: Scheduler settings read here and stored on the
                instance.
            cache_config: Cache settings used to build the block space
                manager.
            lora_config: LoRA configuration; ``None`` when LoRA is not used.
            pipeline_parallel_size: Divisor applied to the configured GPU and
                CPU block counts (when they are set and non-zero).
            output_proc_callback: Optional callback; when it is not ``None``
                async output processing is considered enabled and two sets
                of object caches are kept instead of one.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        version = "selfattn"
        if (self.scheduler_config.runner_type == "pooling"
                or self.cache_config.is_attention_free):
            version = "placeholder"

        BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
            version)

        # Block counts are split across pipeline stages; None / 0 are left
        # untouched.
        num_gpu_blocks = cache_config.num_gpu_blocks
        if num_gpu_blocks:
            num_gpu_blocks //= pipeline_parallel_size

        num_cpu_blocks = cache_config.num_cpu_blocks
        if num_cpu_blocks:
            num_cpu_blocks //= pipeline_parallel_size

        # Create the block space manager.
        self.block_manager = BlockSpaceManagerImpl(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching,
        )

        # Sequence groups in the WAITING state.
        # Contain new prefill or preempted requests.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        # Contain decode requests.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        # Contain decode requests that are swapped out.
        self.swapped: Deque[SequenceGroup] = deque()
        # Sequence groups finished requests ids since last step iteration.
        # It lets the model know that any state associated with these requests
        # can and must be released after the current step.
        # This is used to evict the finished requests from the Mamba cache.
        self._finished_requests_ids: List[str] = []
        # Time at previous scheduling step
        self.prev_time = 0.0
        # Did we schedule a prompt at previous step?
        self.prev_prompt = False
        # Latency of the last prompt step
        self.last_prompt_latency = 0.0
        # preemption mode, RECOMPUTE or SWAP
        self.user_specified_preemption_mode = scheduler_config.preemption_mode

        # The following field is test-only. It is used to inject artificial
        # preemption.
        self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
        self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
                                       if self.enable_artificial_preemption
                                       else 0)
        self.num_cumulative_preemption: int = 0

        # Used to cache python objects
        self._seq_group_metadata_cache: List[PyObjectCache] = []
        self._scheduler_running_outputs_cache: List[PyObjectCache] = []
        self._scheduled_seq_group_cache: List[PyObjectCache] = []

        # For async output processing, we need to swap cache buffers between
        # iterations. I.e. since the output processing is lagged one step,
        # we cannot reuse the cached objects immediately when the schedule()
        # is called again, but only when schedule() is called the second time.
        self.output_proc_callback = output_proc_callback
        self.use_async_output_proc = self.output_proc_callback is not None
        self.num_cache_iters = 2 if self.use_async_output_proc else 1

        self.cache_id = 0
        for _ in range(self.num_cache_iters):
            self._seq_group_metadata_cache.append(
                PyObjectCache(seq_group_metadata_builder))
            self._scheduler_running_outputs_cache.append(
                PyObjectCache(scheduler_running_outputs_builder))
            self._scheduled_seq_group_cache.append(
                PyObjectCache(scheduled_seq_group_builder))

        # For async postprocessor, the extra decode run cannot be done
        # when the request reaches max_model_len. In this case, the request
        # will be stopped during schedule() call and added to this stop list
        # for processing and deallocation by the free_finished_seq_groups()
        self._async_stopped: List[SequenceGroup] = []

        # List with the chunk sizes to hand out to each sequence depending
        # on how many partial prefills are running. This is slightly faster than
        # running an integer division every time a prefill is scheduled.
        # This splits the budget evenly among all prefills.
        # Index 0 holds the full budget; index i holds budget // i.
        self.partial_prefill_budget_lookup_list = [0] * (
            self.scheduler_config.max_num_partial_prefills + 1)
        self.partial_prefill_budget_lookup_list[0] = (
            scheduler_config.max_num_batched_tokens)
        for i in range(1, self.scheduler_config.max_num_partial_prefills + 1):
            self.partial_prefill_budget_lookup_list[i] = (
                scheduler_config.max_num_batched_tokens // i)

539
540
541
    @property
    def next_cache_id(self):
        return (self.cache_id + 1) % self.num_cache_iters
542

543
544
545
546
    @property
    def lora_enabled(self) -> bool:
        return bool(self.lora_config)

547
548
549
550
551
    @property
    def num_decoding_tokens_per_seq(self) -> int:
        """The number of new tokens."""
        return 1

552
    def add_seq_group(self, seq_group: SequenceGroup) -> None:
553
        # Add sequence groups to the waiting queue.
554
        self.waiting.append(seq_group)
Woosuk Kwon's avatar
Woosuk Kwon committed
555

556
557
558
559
560
561
562
563
564
565
    def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
        # Add sequence groups to the running queue.
        # Only for testing purposes.
        self.running.append(seq_group)

    def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
        # Add sequence groups to the swapped queue.
        # Only for testing purposes.
        self.swapped.append(seq_group)

566
567
568
569
570
    def abort_seq_group(
        self,
        request_id: Union[str, Iterable[str]],
        seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
    ) -> None:
571
572
573
574
575
576
577
578
579
580
581
        """Aborts a sequence group with the given ID.

        Check if the sequence group with the given ID
            is present in any of the state queue.
        If present, remove the sequence group from the state queue.
            Also, if any of the sequences in the sequence group is not finished,
                free the sequence with status `FINISHED_ABORTED`.
        Otherwise, do nothing.

        Args:
            request_id: The ID(s) of the sequence group to abort.
582
            seq_id_to_seq_group: helper for groups with n>1
583
        """
Antoni Baum's avatar
Antoni Baum committed
584
585
586
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
587
        seq_id_to_seq_group = seq_id_to_seq_group or {}
588
        for state_queue in [self.waiting, self.running, self.swapped]:
ljss's avatar
ljss committed
589
            aborted_groups: List[SequenceGroup] = []
590
            for seq_group in state_queue:
591
592
593
594
595
596
597
598
599
                # When n>1, seq_group.request_id looks like
                # foo_parallel_sample_0, while request_ids is just foo, and we
                # should resolve it as real_request_id to match.
                if seq_group.request_id in seq_id_to_seq_group:
                    real_request_id = seq_id_to_seq_group[
                        seq_group.request_id].group_id
                else:
                    real_request_id = seq_group.request_id
                if real_request_id in request_ids:
600
601
                    # Appending aborted group into pending list.
                    aborted_groups.append(seq_group)
602
603
604
                    # We can't remove real_request_id in request_ids here,
                    # because there may be other seq groups sharing the same
                    # real_request_id
605
606
607
            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
                state_queue.remove(aborted_group)
608
                # Remove the aborted request from the Mamba cache.
609
                self._finished_requests_ids.append(aborted_group.request_id)
ljss's avatar
ljss committed
610
                for seq in aborted_group.get_seqs():
611
612
613
614
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)
615
616
                if aborted_group.request_id in seq_id_to_seq_group:
                    del seq_id_to_seq_group[aborted_group.request_id]
617

618
619
620
621
622
623
624
625
626
627
628
629
630
                self._free_seq_group_cross_attn_blocks(aborted_group)

    def _free_seq_group_cross_attn_blocks(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """
        Free a sequence group from a cross-attention block table.
        Has no effect on decoder-only models.
        """
        if seq_group.is_encoder_decoder():
            self.block_manager.free_cross(seq_group)

631
    def has_unfinished_seqs(self) -> bool:
632
633
        return (len(self.waiting) != 0 or len(self.running) != 0
                or len(self.swapped) != 0)
634

635
636
637
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_manager.get_prefix_cache_hit_rate(device)

638
639
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        return self.block_manager.reset_prefix_cache(device)
640

641
642
643
    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

Mor Zusman's avatar
Mor Zusman committed
644
645
646
647
648
649
    def get_and_reset_finished_requests_ids(self) -> List[str]:
        """Flushes the list of request ids of previously finished seq_groups."""
        finished_requests_ids = self._finished_requests_ids
        self._finished_requests_ids = list()
        return finished_requests_ids

650
    def _schedule_running(
651
652
653
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
654
        enable_chunking: bool = False,
655
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
656
    ) -> SchedulerRunningOutputs:
657
        """Schedule sequence groups that are running.
658

659
        Running queue should include decode and chunked prefill requests.
Woosuk Kwon's avatar
Woosuk Kwon committed
660

661
662
663
664
665
        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any decodes are preempted.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any decodes are preempted.
666
667
668
669
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled  if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.
670
671
672
            partial_prefill_metadata: information about the partial prefills
            that are currently running

673
        Returns:
674
            SchedulerRunningOutputs.
675
        """
676
677
        ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
            self.cache_id].get_object()
678
679
680
681
682
683
684
685
        ret.blocks_to_swap_out.clear()
        ret.blocks_to_copy.clear()
        ret.decode_seq_groups.clear()
        ret.prefill_seq_groups.clear()
        ret.preempted.clear()
        ret.swapped_out.clear()

        ret.num_lookahead_slots = self._get_num_lookahead_slots(
686
            is_prefill=False, enable_chunking=enable_chunking)
687
688
689
690

        ret.decode_seq_groups_list.clear()
        ret.prefill_seq_groups_list.clear()

691
        # Blocks that need to be swapped or copied before model execution.
692
693
        blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
        blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy
Woosuk Kwon's avatar
Woosuk Kwon committed
694

695
696
697
698
699
        decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
        prefill_seq_groups: List[
            ScheduledSequenceGroup] = ret.prefill_seq_groups
        preempted: List[SequenceGroup] = ret.preempted
        swapped_out: List[SequenceGroup] = ret.swapped_out
Woosuk Kwon's avatar
Woosuk Kwon committed
700

701
702
        running_queue = self.running
        assert len(self._async_stopped) == 0
703
704
        while running_queue:
            seq_group = running_queue[0]
705
706
707
708
709
710
711
            # We discard the cached tokens info here because we don't need it
            # for running sequence:
            #   1. If a sequence is running with chunked prefill, the cached
            #      tokens info was already used for the first prefill.
            #   2. If a sequence is running with non-chunked prefill, then
            #      there it's a decoding sequence, and the cached tokens info is
            #      irrelevant.
712
            num_uncached_new_tokens, _ = \
713
                self._get_num_new_uncached_and_cached_tokens(
714
715
716
717
718
719
                seq_group,
                SequenceStatus.RUNNING,
                enable_chunking,
                budget,
                partial_prefill_metadata,
            )
720
721

            num_running_tokens = num_uncached_new_tokens
722
            if num_running_tokens == 0:
723
                # No budget => Stop
724
                break
725
726

            running_queue.popleft()
727
728
729
730
731

            # With async postprocessor, an extra decode run is done
            # to process the final tokens. The check below avoids this extra
            # decode run when the model max len is reached, in order to avoid
            # a memory overflow.
732
733
            if (self.use_async_output_proc and seq_group.seqs[0].get_len()
                    > self.scheduler_config.max_model_len):
734
735
736
                self._async_stopped.append(seq_group)
                continue

737
738
            # NOTE(woosuk): Preemption happens only when there is no available
            # slot to keep all the sequence groups in the RUNNING state.
739
            while not self._can_append_slots(seq_group, enable_chunking):
740
741
                budget.subtract_num_batched_tokens(seq_group.request_id,
                                                   num_running_tokens)
742
                num_running_seqs = seq_group.get_max_num_running_seqs()
743
744
                budget.subtract_num_seqs(seq_group.request_id,
                                         num_running_seqs)
745
746
747

                if (curr_loras is not None and seq_group.lora_int_id > 0
                        and seq_group.lora_int_id in curr_loras):
748
                    curr_loras.remove(seq_group.lora_int_id)
749

750
751
                # Determine victim sequence
                cont_loop = True
752
                if running_queue:
753
                    # Preempt the lowest-priority sequence group.
754
                    victim_seq_group = running_queue.pop()
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
                else:
                    # No other sequence group can be preempted.
                    # Preempt the current sequence group.
                    # Note: This is also where we stop this loop
                    # (since there is nothing else to preempt)
                    victim_seq_group = seq_group
                    cont_loop = False

                # With async postprocessor, before preempting a sequence
                # we need to ensure it has no pending async postprocessor
                do_preempt = True
                if self.use_async_output_proc:
                    assert self.output_proc_callback is not None
                    self.output_proc_callback(
                        request_id=victim_seq_group.request_id)

                    # It may be that the async pending "victim_seq_group"
                    # becomes finished, in which case we simply free it.
                    if victim_seq_group.is_finished():
                        self._free_finished_seq_group(victim_seq_group)
                        do_preempt = False

                # Do preemption
                if do_preempt:
779
780
781
782
783
784
                    preempted_mode = self._preempt(victim_seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(victim_seq_group)
                    else:
                        swapped_out.append(victim_seq_group)
785
786

                if not cont_loop:
Woosuk Kwon's avatar
Woosuk Kwon committed
787
788
                    break
            else:
789
                self._append_slots(seq_group, blocks_to_copy, enable_chunking)
790
                is_prefill = seq_group.is_prefill()
791

792
793
794
                scheduled_seq_group: ScheduledSequenceGroup = (
                    self._scheduled_seq_group_cache[
                        self.cache_id].get_object())
795
                scheduled_seq_group.seq_group = seq_group
796
                if is_prefill:
797
798
799
                    scheduled_seq_group.token_chunk_size = num_running_tokens
                    prefill_seq_groups.append(scheduled_seq_group)
                    ret.prefill_seq_groups_list.append(seq_group)
800
                else:
801
802
803
804
                    scheduled_seq_group.token_chunk_size = 1
                    decode_seq_groups.append(scheduled_seq_group)
                    ret.decode_seq_groups_list.append(seq_group)

805
806
                budget.add_num_batched_tokens(seq_group.request_id,
                                              num_running_tokens)
807
808
809
810
811
812
813
                # OPTIMIZATION:  Note that get_max_num_running_seqs is
                # expensive. For the default scheduling chase where
                # enable_chunking is False, num_seqs are updated before running
                # this method, so we don't have to update it again here.
                if enable_chunking:
                    num_running_seqs = seq_group.get_max_num_running_seqs()
                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)
814
815
816
                if curr_loras is not None and seq_group.lora_int_id > 0:
                    curr_loras.add(seq_group.lora_int_id)

817
818
        self._scheduler_running_outputs_cache[self.next_cache_id].reset()
        self._scheduled_seq_group_cache[self.next_cache_id].reset()
819
820

        return ret
821

822
823
824
825
    def _schedule_swapped(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerSwappedInOutputs:
        """Schedule sequence groups that are swapped out.

        It schedules swapped requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are swapped in.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are swapped in.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled  if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerSwappedInOutputs.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: List[Tuple[int, int]] = []
        blocks_to_copy: List[Tuple[int, int]] = []
        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []
        infeasible_seq_groups: List[SequenceGroup] = []

        swapped_queue = self.swapped

        # Groups skipped because no LoRA slot is free. They are collected with
        # appendleft() and re-inserted with extendleft() below; the two
        # reversals cancel out, so they return to the front of the queue in
        # their original relative order.
        leftover_swapped: Deque[SequenceGroup] = deque()
        while swapped_queue:
            seq_group = swapped_queue[0]

            # If the sequence group cannot be swapped in, stop.
            is_prefill = seq_group.is_prefill()
            alloc_status = self.block_manager.can_swap_in(
                seq_group,
                self._get_num_lookahead_slots(is_prefill, enable_chunking))
            if alloc_status == AllocStatus.LATER:
                break
            elif alloc_status == AllocStatus.NEVER:
                logger.warning(
                    "Failing the request %s because there's not enough kv "
                    "cache blocks to run the entire sequence.",
                    seq_group.request_id,
                )
                for seq in seq_group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible_seq_groups.append(seq_group)
                swapped_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_swapped.appendleft(seq_group)
                    swapped_queue.popleft()
                    continue

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
            num_new_tokens_uncached, num_new_tokens_cached = (
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group, SequenceStatus.SWAPPED, enable_chunking,
                    budget))

            if num_new_tokens_uncached == 0 or not budget.can_schedule(
                    num_new_tokens=num_new_tokens_uncached,
                    num_new_seqs=num_new_seqs,
            ):
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.SWAPPED)
                break

            if lora_int_id > 0 and curr_loras is not None:
                curr_loras.add(lora_int_id)
            swapped_queue.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slots(seq_group, blocks_to_copy, enable_chunking)
            if is_prefill:
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(
                        seq_group,
                        token_chunk_size=num_new_tokens_uncached +
                        num_new_tokens_cached,
                    ))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
            budget.add_num_batched_tokens(
                seq_group.request_id,
                num_batched_tokens=num_new_tokens_uncached,
                num_cached_tokens=num_new_tokens_cached,
            )
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        swapped_queue.extendleft(leftover_swapped)

        return SchedulerSwappedInOutputs(
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False, enable_chunking=enable_chunking),
            infeasible_seq_groups=infeasible_seq_groups,
        )

    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
943
944
        if (self.scheduler_config.chunked_prefill_enabled
                and not self.scheduler_config.is_multi_step):
945
946
            prompt_limit = self.scheduler_config.max_model_len
        else:
947
948
949
950
            prompt_limit = min(
                self.scheduler_config.max_model_len,
                self.scheduler_config.max_num_batched_tokens,
            )
951
952

        # Model is fine tuned with long context. Return the fine tuned max_len.
953
        if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
954
955
956
957
958
            assert prompt_limit <= seq_group.lora_request.long_lora_max_len
            return seq_group.lora_request.long_lora_max_len
        else:
            return prompt_limit

959
960
    def _get_priority(self,
                      seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
961
        """Get the priority of the sequence group.
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
        Highest preference to user-defined priority, followed by arrival time.
        Args:
            seq_group: The sequence group input.
        Returns:
            The priority of the sequence group.
        """
        return seq_group.priority, seq_group.arrival_time

    def _schedule_priority_preemption(
        self,
        budget: SchedulingBudget,
    ) -> int:
        """Sorts waiting and running queue. Also, force preempt requests
        from the running queue if their priority is lower.
        Priority-based preemption is used with the priority policy.
        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are scheduled.
        Returns:
            A count of priority-based preemptions.
        """

        waiting_queue = self.waiting

        running_queue = deque(sorted(self.running, key=self._get_priority))

        blocks_to_swap_out: List[Tuple[int, int]] = []
        force_preemption_count = 0

        if waiting_queue:
            seq_group = waiting_queue.popleft()
            num_new_seqs = seq_group.get_max_num_running_seqs()
994
            num_new_tokens_uncached, _ = \
995
                self._get_num_new_uncached_and_cached_tokens(
996
                seq_group, SequenceStatus.WAITING, False, budget)
997

998
            # Only preempt if priority inversion exists
999
1000
            while running_queue and self._get_priority(
                    running_queue[-1]) > self._get_priority(seq_group):
1001
                # Only preempt if waiting sequence cannot be allocated
1002
                can_allocate = self.block_manager.can_allocate(seq_group)
1003
1004
1005
1006
1007
1008
                if (num_new_tokens_uncached > 0
                        and can_allocate == AllocStatus.OK
                        and budget.can_schedule(
                            num_new_tokens=num_new_tokens_uncached,
                            num_new_seqs=num_new_seqs,
                        )):
1009
1010
                    break

1011
                # Adjust budget to remove the victim sequence group
1012
                vseq_group = running_queue.pop()
1013
1014
1015
1016
1017
                num_running_tokens_uncached, _ = (
                    self._get_num_new_uncached_and_cached_tokens(
                        vseq_group, SequenceStatus.RUNNING, False, budget))
                budget.subtract_num_batched_tokens(
                    vseq_group.request_id, num_running_tokens_uncached)
1018
1019
1020
1021
                num_running_seqs = vseq_group.get_max_num_running_seqs()
                budget.subtract_num_seqs(vseq_group.request_id,
                                         num_running_seqs)

1022
                # Preempt out the victim sequence group
1023
                self._preempt(vseq_group, blocks_to_swap_out)
1024
1025
                waiting_queue.appendleft(vseq_group)
                force_preemption_count += 1
1026
            # Put the sequence back into the waiting queue
1027
1028
            waiting_queue.appendleft(seq_group)

1029
1030
1031
            self.remove_seq_from_computed_blocks_tracker(
                seq_group, SequenceStatus.WAITING)

1032
1033
1034
1035
1036
1037
        waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))

        self.waiting = waiting_queue
        self.running = running_queue
        return force_preemption_count

1038
1039
1040
1041
    def _schedule_prefills(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> SchedulerPrefillOutputs:
        """Schedule sequence groups that are in prefill stage.

        Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
        as a new prefill (that starts from beginning -> most recently generated
        tokens).

        It schedules waiting requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are scheduled.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are scheduled.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled  if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.
            partial_prefill_metadata: information about the partial prefills
                that are currently running

        Returns:
            SchedulerPrefillOutputs.
        """
        if budget.remaining_token_budget() == 0:
            # Do nothing: Can't add any more prefill anyway
            return SchedulerPrefillOutputs(
                seq_groups=[],
                ignored_seq_groups=[],
                num_lookahead_slots=self._get_num_lookahead_slots(
                    is_prefill=True, enable_chunking=enable_chunking),
            )
        ignored_seq_groups: List[SequenceGroup] = []
        seq_groups: List[ScheduledSequenceGroup] = []
        using_prompt_embeds: bool = False

        waiting_queue = self.waiting

        # Groups that are skipped (not failed) in this pass. They are pushed
        # with appendleft() and restored with extendleft() after the loop, so
        # they go back to the front of the queue in their original order.
        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
        while self._passed_delay(time.time()) and waiting_queue:
            seq_group = waiting_queue[0]

            waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            if (partial_prefill_metadata is not None
                    and not partial_prefill_metadata.can_schedule(seq_group)):
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue
            num_new_tokens_uncached, num_new_tokens_cached = (
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group,
                    SequenceStatus.WAITING,
                    enable_chunking,
                    budget,
                    partial_prefill_metadata=partial_prefill_metadata,
                ))
            num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached

            if not enable_chunking:
                num_prompt_tokens = waiting_seqs[0].get_len()
                assert num_new_tokens == num_prompt_tokens

            prompt_limit = self._get_prompt_limit(seq_group)
            if num_new_tokens > prompt_limit:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds limit of %d",
                    num_new_tokens,
                    prompt_limit,
                )
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.FINISHED_IGNORED)
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            num_lookahead_slots: int = 0
            if self.scheduler_config.is_multi_step and enable_chunking:
                num_lookahead_slots = self._get_num_lookahead_slots(
                    True, enable_chunking)

            # If the sequence group cannot be allocated, stop.
            can_allocate = self.block_manager.can_allocate(
                seq_group, num_lookahead_slots=num_lookahead_slots)
            if can_allocate == AllocStatus.LATER:
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break
            elif can_allocate == AllocStatus.NEVER:
                logger.warning(
                    "Input prompt (%d tokens) + lookahead slots (%d) is "
                    "too long and exceeds the capacity of block_manager",
                    num_new_tokens,
                    num_lookahead_slots,
                )
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.FINISHED_IGNORED)
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            # We cannot mix sequence groups that use prompt embeds and
            # those that do not.
            if len(seq_groups) == 0:
                using_prompt_embeds = seq_group.uses_prompt_embeds()
            if using_prompt_embeds != seq_group.uses_prompt_embeds():
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and lora_int_id not in curr_loras
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    self.remove_seq_from_computed_blocks_tracker(
                        seq_group, SequenceStatus.WAITING)
                    leftover_waiting_sequences.appendleft(seq_group)
                    waiting_queue.popleft()
                    continue

            if (budget.num_batched_tokens
                    >= self.scheduler_config.max_num_batched_tokens):
                # We've reached the budget limit - since there might be
                # continuous prefills in the running queue, we should break
                # to avoid scheduling any new prefills.
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break

            num_new_seqs = seq_group.get_max_num_running_seqs()
            if num_new_tokens_uncached == 0 or not budget.can_schedule(
                    num_new_tokens=num_new_tokens_uncached,
                    num_new_seqs=num_new_seqs,
            ):
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break

            # Can schedule this request.
            if curr_loras is not None and lora_int_id > 0:
                curr_loras.add(lora_int_id)
            waiting_queue.popleft()
            self._allocate_and_set_running(seq_group)

            if partial_prefill_metadata is not None:
                partial_prefill_metadata.maybe_increment_partial_prefills(
                    seq_group)

            if enable_chunking and self.scheduler_config.is_multi_step:
                blocks_to_copy: List[Tuple[int, int]] = []
                # init_multi_step_from_lookahead_slots happens in append_slots
                self._append_slots(seq_group, blocks_to_copy, enable_chunking)
                # This assert will trip when a copy-on-write happens. This is
                # not a concern as the very first sequence-group block
                # allocation happens above. Still, we have the assert to
                # catch any edge-cases.
                assert not blocks_to_copy
            else:
                seq_group.init_multi_step_from_lookahead_slots(
                    num_lookahead_slots,
                    num_scheduler_steps=self.scheduler_config.
                    num_scheduler_steps,
                    is_multi_step=self.scheduler_config.is_multi_step,
                    enable_chunking=enable_chunking,
                )

            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=num_new_tokens))
            budget.add_num_batched_tokens(
                seq_group.request_id,
                num_batched_tokens=num_new_tokens_uncached,
                num_cached_tokens=num_new_tokens_cached,
            )
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Queue requests that couldn't be scheduled.
        waiting_queue.extendleft(leftover_waiting_sequences)
        if len(seq_groups) > 0:
            self.prev_prompt = True

        return SchedulerPrefillOutputs(
            seq_groups=seq_groups,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=True, enable_chunking=enable_chunking),
        )

    def _schedule_default(self) -> SchedulerOutputs:
        """Schedule queued requests.

        The current policy is designed to optimize the throughput. First,
        it batches as many prefill requests as possible. And it schedules
        decodes. If there's a pressure on GPU memory, decode requests can
        be swapped or preempted.

        Returns:
            SchedulerOutputs describing the scheduled sequence groups, the
            block swap/copy operations, and the ignored sequence groups.
        """
        # Include running requests to the budget.
        budget = SchedulingBudget(
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
        # Make sure we include num running seqs before scheduling prefill,
        # so that we don't schedule beyond max_num_seqs for prefill.
        for seq_group in self.running:
            budget.add_num_seqs(seq_group.request_id,
                                seq_group.get_max_num_running_seqs())
        curr_loras = (set(
            seq_group.lora_int_id for seq_group in self.running
            if seq_group.lora_int_id > 0) if self.lora_enabled else None)

        prefills = SchedulerPrefillOutputs.create_empty()
        running_scheduled = SchedulerRunningOutputs.create_empty()
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # If any requests are swapped, prioritized swapped requests.
        if not self.swapped:
            prefills = self._schedule_prefills(budget,
                                               curr_loras,
                                               enable_chunking=False)

        if len(prefills.seq_groups
               ) == 0 and self.scheduler_config.policy == "priority":
            self._schedule_priority_preemption(budget)

        # Don't schedule decodes if prefills are scheduled.
        # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
        # only contains decode requests, not chunked prefills.
        if len(prefills.seq_groups) == 0:
            running_scheduled = self._schedule_running(budget,
                                                       curr_loras,
                                                       enable_chunking=False)

            # If any sequence group is preempted, do not swap in any sequence
            # group. because it means there's no slot for new running requests.
            if (len(running_scheduled.preempted) +
                    len(running_scheduled.swapped_out) == 0):
                swapped_in = \
                    self._schedule_swapped(budget, curr_loras)

        assert (budget.num_batched_tokens
                <= self.scheduler_config.max_num_batched_tokens)
        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

        # Update waiting requests.
        self.waiting.extendleft(running_scheduled.preempted)
        # Update new running requests.
        if len(prefills.seq_groups) > 0:
            self.running.extend([s.seq_group for s in prefills.seq_groups])

        self.running.extend(running_scheduled.decode_seq_groups_list)

        if len(swapped_in.decode_seq_groups) > 0:
            self.running.extend(
                [s.seq_group for s in swapped_in.decode_seq_groups])

        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)
        preempted = len(running_scheduled.preempted) + len(
            running_scheduled.swapped_out)

        # There should be no prefill from running queue because this policy
        # doesn't allow chunked prefills.
        assert len(running_scheduled.prefill_seq_groups) == 0
        assert len(swapped_in.prefill_seq_groups) == 0

        # Merge lists
        num_prefill_groups = len(prefills.seq_groups)
        ignored_seq_groups_for_embeds: List[SequenceGroup] = []
        if num_prefill_groups > 0:
            scheduled_seq_groups = prefills.seq_groups
            scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
        else:
            scheduled_seq_groups = running_scheduled.decode_seq_groups
            if len(scheduled_seq_groups) > 0:
                # Groups that use prompt embeds cannot be batched with groups
                # that do not. The first scheduled group decides which kind
                # this step runs; the others are reported as ignored.
                using_prompt_embeds = scheduled_seq_groups[
                    0].seq_group.uses_prompt_embeds()
                kept_seq_groups: List[ScheduledSequenceGroup] = []
                for scheduled in scheduled_seq_groups:
                    if (scheduled.seq_group.uses_prompt_embeds() ==
                            using_prompt_embeds):
                        kept_seq_groups.append(scheduled)
                    else:
                        ignored_seq_groups_for_embeds.append(
                            scheduled.seq_group)
                # Only rebind when something was filtered out, so that the
                # original list object is kept otherwise.
                if len(ignored_seq_groups_for_embeds) > 0:
                    scheduled_seq_groups = kept_seq_groups

        scheduled_seq_groups.extend(swapped_in.decode_seq_groups)

        blocks_to_copy = running_scheduled.blocks_to_copy
        blocks_to_copy.extend(swapped_in.blocks_to_copy)

        ignored_seq_groups = prefills.ignored_seq_groups
        ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
        ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)

        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_seq_groups,
            num_prefill_groups=num_prefill_groups,
            num_batched_tokens=budget.num_batched_tokens +
            budget.num_cached_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=preempted,
        )

    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
        """Build one scheduling step under the chunked-prefill policy.

        Chunked prefill lets a prompt be split into chunks that are batched
        together with decode requests. The policy, in order:

        1. schedule as many decoding requests as possible,
        2. schedule chunked prefill requests that are not finished,
        3. schedule swapped requests,
        4. schedule new prefill requests.

        Mixing prefills and decodes in one batch keeps GPU utilization high,
        and inter-token latency improves because decodes are not blocked
        behind prefill requests.

        Side effects: ``self.waiting``, ``self.running`` and ``self.swapped``
        are updated in place with the outcome of this step.

        Returns:
            The ``SchedulerOutputs`` describing the batch, with prefill
            groups listed before decode groups.
        """

        def _seq_groups_of(scheduled):
            # Project scheduled entries onto their underlying sequence group.
            return [entry.seq_group for entry in scheduled]

        config = self.scheduler_config
        budget = SchedulingBudget(
            token_budget=config.max_num_batched_tokens,
            max_num_seqs=config.max_num_seqs,
        )
        curr_loras: Set[int] = set()

        prefills = SchedulerPrefillOutputs.create_empty()
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Snapshot of the partial prefills derived from the current queues.
        partial_prefill_metadata = PartialPrefillMetadata.from_queues(
            running=self.running,
            waiting=self.waiting,
            scheduler_config=config,
        )

        # The running queue always goes first (FCFS).
        running_scheduled = self._schedule_running(
            budget,
            curr_loras,
            enable_chunking=True,
            partial_prefill_metadata=partial_prefill_metadata,
        )

        # Only attempt swap-in when nothing had to be evicted: any preemption
        # or swap-out means there is no room to bring requests back.
        nothing_evicted = (len(running_scheduled.preempted) +
                           len(running_scheduled.swapped_out)) == 0
        if nothing_evicted:
            swapped_in = self._schedule_swapped(budget, curr_loras)

        prefills = self._schedule_prefills(
            budget,
            curr_loras,
            enable_chunking=True,
            partial_prefill_metadata=partial_prefill_metadata,
        )

        assert budget.num_batched_tokens <= config.max_num_batched_tokens
        assert budget.num_curr_seqs <= config.max_num_seqs

        # Preempted requests return to the front of the waiting queue.
        self.waiting.extendleft(running_scheduled.preempted)

        # Rebuild the running queue. Without chunked prefill the scheduler
        # prioritizes prefills; with it enabled, decodes come first.
        self.running.extend(_seq_groups_of(swapped_in.decode_seq_groups))
        self.running.extend(_seq_groups_of(swapped_in.prefill_seq_groups))
        self.running.extend(
            _seq_groups_of(running_scheduled.decode_seq_groups))
        # Several prefills may run concurrently. Those that finish in this
        # step are listed ahead of those that do not, so that once they become
        # decodes in the next iteration they are prioritized over sequences
        # still prefilling.
        finishing_first = self._order_finishing_prefills_first(
            running_scheduled.prefill_seq_groups)
        self.running.extend(finishing_first)
        self.running.extend(_seq_groups_of(prefills.seq_groups))

        # Swapped-out requests join the swapped queue.
        self.swapped.extend(running_scheduled.swapped_out)

        # The attention backend assumes prefills precede decodes.
        scheduled_seq_groups = (prefills.seq_groups +
                                running_scheduled.prefill_seq_groups +
                                swapped_in.prefill_seq_groups +
                                running_scheduled.decode_seq_groups +
                                swapped_in.decode_seq_groups)
        num_prefill_groups = (len(prefills.seq_groups) +
                              len(swapped_in.prefill_seq_groups) +
                              len(running_scheduled.prefill_seq_groups))

        # An all-prompt batch (outside multi-step) uses zero lookahead slots,
        # which takes the `no_spec` path in `spec_decode_worker.py`.
        all_prefills = len(scheduled_seq_groups) == num_prefill_groups
        if all_prefills and not config.is_multi_step:
            num_lookahead_slots = 0
        else:
            num_lookahead_slots = running_scheduled.num_lookahead_slots

        num_preempted = (len(running_scheduled.preempted) +
                         len(running_scheduled.swapped_out))
        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_seq_groups,
            num_prefill_groups=num_prefill_groups,
            num_batched_tokens=budget.num_batched_tokens +
            budget.num_cached_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=running_scheduled.blocks_to_copy +
            swapped_in.blocks_to_copy,
            ignored_seq_groups=prefills.ignored_seq_groups +
            swapped_in.infeasible_seq_groups,
            num_lookahead_slots=num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=num_preempted,
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
1491

1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
    def _order_finishing_prefills_first(
        self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
    ) -> List[SequenceGroup]:
        """Return the prefilling SequenceGroups with finishing ones first.

        A scheduled prefill "finishes" in this step when its number of
        uncomputed tokens equals the chunk size it was given. Relative order
        inside each of the two partitions follows the input order.
        """
        finishing: List[SequenceGroup] = []
        not_finishing: List[SequenceGroup] = []
        for scheduled in scheduled_prefill_seqs:
            group = scheduled.seq_group
            remaining = group.get_num_uncomputed_tokens()
            if remaining == scheduled.token_chunk_size:
                finishing.append(group)
            else:
                not_finishing.append(group)
        return finishing + not_finishing

1507
1508
1509
1510
1511
1512
1513
    def _schedule(self) -> SchedulerOutputs:
        """Dispatch to the scheduling policy selected by the config.

        Uses the chunked-prefill policy when
        ``scheduler_config.chunked_prefill_enabled`` is set, otherwise the
        default policy.
        """
        if not self.scheduler_config.chunked_prefill_enabled:
            return self._schedule_default()
        return self._schedule_chunked_prefill()

1514
1515
    def _can_append_slots(self, seq_group: SequenceGroup,
                          enable_chunking: bool) -> bool:
        """Report whether the KV cache has room to keep generating.

        Args:
            seq_group: The sequence group that wants more slots.
            enable_chunking: Forwarded to the lookahead-slot computation.

        Returns:
            The block manager's answer for this group and its lookahead
            slot count, or ``False`` when the test-only artificial
            preemption fires.
        """
        # Test-only: randomly pretend there is no space so that preemption
        # paths get exercised. The random draw happens before the remaining
        # budget is checked.
        if self.enable_artificial_preemption:
            roll = random.uniform(0, 1)
            if (roll < ARTIFICIAL_PREEMPTION_PROB
                    and self.artificial_preempt_cnt > 0):
                self.artificial_preempt_cnt -= 1
                return False

        prefilling = seq_group.is_prefill()
        lookahead = self._get_num_lookahead_slots(prefilling, enable_chunking)

        if prefilling and lookahead > 0:
            # Prefill slots are only appended when multi-step and
            # chunked-prefill are enabled together.
            assert self.scheduler_config.is_multi_step and enable_chunking

        return self.block_manager.can_append_slots(
            seq_group=seq_group, num_lookahead_slots=lookahead)
1537

1538
    def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
        """Tell whether async output processing is allowed for this group.

        It is allowed only for a single-sequence group: either the group has
        no sampling params, or its sampling params request ``n == 1``.
        """
        params = seq_group.sampling_params
        if params is None:
            return True
        return params.n == 1
1544
1545
1546
1547

    def schedule(
            self
    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
        """Run one scheduling step and build the metadata for the batch.

        Calls ``self._schedule()`` (which changes internal state such as
        ``self.running``, ``self.swapped`` and ``self.waiting``), then creates
        one metadata object per scheduled sequence group, marks the scheduled
        blocks as computed, and adds the time spent here to the
        ``scheduler_time`` metric of every running sequence group.

        Returns:
            A tuple ``(seq_group_metadata_list, scheduler_outputs,
            allow_async_output_proc)``. The last element starts from
            ``self.use_async_output_proc`` and becomes the result of
            ``_allow_async_output_proc`` for the scheduled groups checked
            while it is still true.
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_start_time = time.perf_counter()

        scheduler_outputs: SchedulerOutputs = self._schedule()
        now = time.time()

        # Without prefix caching there are never common computed blocks; with
        # it, the value is recomputed for every sequence group in the loop.
        if not self.cache_config.enable_prefix_caching:
            common_computed_block_nums = []

        allow_async_output_proc: bool = self.use_async_output_proc

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            token_chunk_size = scheduled_seq_group.token_chunk_size
            seq_group.maybe_set_first_scheduled_time(now)

            # NOTE(review): `seq_group_metadata` is rebound in both branches
            # further down, so this cached object is not the one appended to
            # the result. The fetch and clears are kept as-is; confirm whether
            # the cache relies on get_object() being called here.
            seq_group_metadata = self._seq_group_metadata_cache[
                self.cache_id].get_object()
            seq_group_metadata.seq_data.clear()
            seq_group_metadata.block_tables.clear()

            # seq_id -> SequenceData
            seq_data: Dict[int, SequenceData] = {}
            # seq_id -> physical block numbers
            block_tables: Dict[int, List[int]] = {}

            if seq_group.is_encoder_decoder():
                # Encoder associated with SequenceGroup
                encoder_seq = seq_group.get_encoder_seq()
                assert encoder_seq is not None
                encoder_seq_data = encoder_seq.data
                # Block table for cross-attention
                # Also managed at SequenceGroup level
                cross_block_table = self.block_manager.get_cross_block_table(
                    seq_group)
            else:
                encoder_seq_data = None
                cross_block_table = None

            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                self.block_manager.access_all_blocks_in_seq(seq, now)

            if self.cache_config.enable_prefix_caching:
                common_computed_block_nums = (
                    self.block_manager.get_common_computed_block_ids(
                        seq_group.get_seqs(status=SequenceStatus.RUNNING)))

            do_sample = True
            is_prompt = seq_group.is_prefill()
            # We should send the metadata to workers when the first prefill
            # is sent. Subsequent requests could be chunked prefill or decode.
            is_first_prefill = False
            if is_prompt:
                seqs = seq_group.get_seqs()
                # Prefill has only 1 sequence.
                assert len(seqs) == 1
                num_computed_tokens = seqs[0].data.get_num_computed_tokens()
                is_first_prefill = num_computed_tokens == 0
                # If, after this step, not all prompt tokens are computed,
                # the prefill is chunked and we don't need sampling.
                # NOTE: We use get_len instead of get_prompt_len because when
                # a sequence is preempted, prefill includes previous generated
                # output tokens.
                if (token_chunk_size + num_computed_tokens
                        < seqs[0].data.get_len()):
                    do_sample = False

            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
            if is_first_prefill or not self.scheduler_config.send_delta_data:
                seq_group_metadata = SequenceGroupMetadata(
                    request_id=seq_group.request_id,
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=seq_group.sampling_params,
                    block_tables=block_tables,
                    do_sample=do_sample,
                    pooling_params=seq_group.pooling_params,
                    token_chunk_size=token_chunk_size,
                    lora_request=seq_group.lora_request,
                    computed_block_nums=common_computed_block_nums,
                    encoder_seq_data=encoder_seq_data,
                    cross_block_table=cross_block_table,
                    state=seq_group.state,
                    token_type_ids=seq_group.token_type_ids,
                    # `multi_modal_data` will only be present for the 1st comm
                    # between engine and worker.
                    # the subsequent comms can still use delta, but
                    # `multi_modal_data` will be None.
                    multi_modal_data=(seq_group.multi_modal_data
                                      if scheduler_outputs.num_prefill_groups
                                      > 0 else None),
                    multi_modal_placeholders=(
                        seq_group.multi_modal_placeholders
                        if scheduler_outputs.num_prefill_groups > 0 else None),
                    prompt_adapter_request=seq_group.prompt_adapter_request,
                )
            else:
                # When SPMD mode is enabled, we only send delta data except for
                # the first request to reduce serialization cost.
                seq_data_delta = {}
                for delta_seq_id, data in seq_data.items():
                    seq_data_delta[delta_seq_id] = data.get_delta_and_reset()
                seq_group_metadata = SequenceGroupMetadataDelta(
                    seq_data_delta,
                    seq_group.request_id,
                    block_tables,
                    is_prompt,
                    do_sample=do_sample,
                    token_chunk_size=token_chunk_size,
                    computed_block_nums=common_computed_block_nums,
                )
            seq_group_metadata_list.append(seq_group_metadata)

            # Once one group disallows async output processing, stop asking.
            if allow_async_output_proc:
                allow_async_output_proc = self._allow_async_output_proc(
                    seq_group)

        # Now that the batch has been created, we can assume all blocks in the
        # batch will have been computed before the next scheduling invocation.
        # This is because the engine assumes that a failure in model execution
        # will crash the vLLM instance / will not retry.
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            self.block_manager.mark_blocks_as_computed(
                scheduled_seq_group.seq_group,
                scheduled_seq_group.token_chunk_size)

        self._seq_group_metadata_cache[self.next_cache_id].reset()

        scheduler_time = time.perf_counter() - scheduler_start_time
        # Add this to scheduler time to all the sequences that are currently
        # running. This will help estimate if the scheduler is a significant
        # component in the e2e latency.
        for seq_group in self.running:
            if seq_group is not None and seq_group.metrics is not None:
                if seq_group.metrics.scheduler_time is not None:
                    seq_group.metrics.scheduler_time += scheduler_time
                else:
                    seq_group.metrics.scheduler_time = scheduler_time

        # Move to next cache (if exists)
        self.cache_id = self.next_cache_id

        # Return results
        return (seq_group_metadata_list, scheduler_outputs,
                allow_async_output_proc)
1702

1703
1704
    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Fork ``child_seq`` from ``parent_seq`` in the block manager.

        Thin wrapper that forwards both sequences to
        ``self.block_manager.fork``.
        """
        self.block_manager.fork(parent_seq, child_seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
1705

1706
    def free_seq(self, seq: Sequence) -> None:
        """Free a sequence from a block table.

        Delegates to ``self.block_manager.free``.
        """
        manager = self.block_manager
        manager.free(seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
1709

1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
    def remove_seq_from_computed_blocks_tracker(
            self, seq_group: SequenceGroup,
            status: Optional[SequenceStatus]) -> None:
        """Drop the group's sequences from the computed-blocks tracker.

        Args:
            seq_group: The group whose sequences are removed.
            status: Passed to ``seq_group.get_seqs`` as its ``status``
                filter to select which sequences are removed.
        """
        for seq in seq_group.get_seqs(status=status):
            self._remove_seq_from_computed_blocks_tracker(seq)

    def _remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
        """
        Remove a single sequence from the block manager's computed-blocks
        tracker (its _seq_id_to_blocks_hashes and
        _seq_id_to_num_tokens_computed entries).

        Thin wrapper around
        ``self.block_manager.remove_seq_from_computed_blocks_tracker``.
        """
        self.block_manager.remove_seq_from_computed_blocks_tracker(seq)

1724
1725
1726
1727
1728
1729
    def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
        """Free every finished sequence of ``seq_group``.

        Sequences that are not finished are left untouched.
        """
        finished_seqs = (seq for seq in seq_group.get_seqs()
                         if seq.is_finished())
        for seq in finished_seqs:
            self.free_seq(seq)

1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
    def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
        """Release resources of a sequence group after a step.

        For a finished group, its cross-attention block table (if any) is
        freed and its request id is recorded in
        ``self._finished_requests_ids``. In every case the group's finished
        sequences are freed.
        """
        group_done = seq_group.is_finished()
        if group_done:
            # Free cross-attention block table, if it exists
            self._free_seq_group_cross_attn_blocks(seq_group)

            # Record the finished request; this list is used to update the
            # Mamba cache in the next step.
            finished_id = seq_group.request_id
            self._finished_requests_ids.append(finished_id)

        # Free finished seqs
        self._free_finished_seqs(seq_group)

1743
    def free_finished_seq_groups(self) -> None:
        """Free finished groups and drop them from the running queue.

        Every group in ``self.running`` is passed to
        ``_free_finished_seq_group``; only unfinished groups are kept, in
        their original order, in a new deque assigned to ``self.running``.
        Groups in ``self._async_stopped`` are then freed the same way and
        that collection is cleared.
        """
        still_running: Deque[SequenceGroup] = deque()
        for group in self.running:
            self._free_finished_seq_group(group)
            if group.is_finished():
                continue
            still_running.append(group)

        self.running = still_running

        # Handle async stopped sequence groups
        # (ones that reached max model len)
        async_stopped = self._async_stopped
        if not async_stopped:
            return

        for group in async_stopped:
            self._free_seq_group_cross_attn_blocks(group)
            self._finished_requests_ids.append(group.request_id)

            # Free finished seqs
            self._free_finished_seqs(group)

        async_stopped.clear()

1764
    def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for ``seq_group`` and mark its sequences RUNNING.

        Only sequences currently in the WAITING status are switched.
        """
        self.block_manager.allocate(seq_group)
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        for waiting_seq in waiting_seqs:
            waiting_seq.status = SequenceStatus.RUNNING

1769
1770
1771
1772
1773
1774
    def _append_slots(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: List[Tuple[int, int]],
        enable_chunking: bool = False,
    ) -> None:
        """Append new slots to the sequences of the given sequence group.

        Args:
            seq_group (SequenceGroup): The sequence group whose sequences get
                new slots.
            blocks_to_copy (List[Tuple[int, int]]): Output list of
                (source block index, destination block index) pairs. It is
                extended in place with the pairs the block manager returns
                for the appended slots.
            enable_chunking (bool): True if chunked prefill is enabled.
        """
        prefilling: bool = seq_group.is_prefill()
        lookahead: int = self._get_num_lookahead_slots(prefilling,
                                                       enable_chunking)
        config = self.scheduler_config

        seq_group.init_multi_step_from_lookahead_slots(
            lookahead,
            num_scheduler_steps=config.num_scheduler_steps,
            is_multi_step=config.is_multi_step,
            enable_chunking=enable_chunking,
        )

        # In multi-step chunked-prefill any sequence type can have slots
        # appended, so no status filter is applied; otherwise only RUNNING
        # sequences are considered.
        status_filter: Optional[SequenceStatus]
        if config.is_multi_step and enable_chunking:
            status_filter = None
        else:
            status_filter = SequenceStatus.RUNNING

        for seq in seq_group.get_seqs(status=status_filter):
            copy_pairs = self.block_manager.append_slots(seq, lookahead)
            if len(copy_pairs) == 0:
                continue
            blocks_to_copy.extend(copy_pairs)
1808

1809
1810
    def _preempt(self, seq_group: SequenceGroup,
                 blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode:
        """Preempt ``seq_group`` and return the mode that was used.

        Args:
            seq_group: The sequence group to preempt.
            blocks_to_swap_out: Output list extended with block mappings
                when the group is preempted by swapping.

        Returns:
            ``PreemptionMode.RECOMPUTE`` or ``PreemptionMode.SWAP``.
        """
        # Without a user-specified mode, recomputation is the default since
        # it incurs lower overhead than swapping. Recomputation is not
        # currently supported for groups with multiple sequences (e.g., beam
        # search), so those are swapped instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        requested_mode = self.user_specified_preemption_mode
        if requested_mode is None:
            single_seq = seq_group.get_max_num_running_seqs() == 1
            preemption_mode = (PreemptionMode.RECOMPUTE
                               if single_seq else PreemptionMode.SWAP)
        elif requested_mode == "swap":
            preemption_mode = PreemptionMode.SWAP
        else:
            preemption_mode = PreemptionMode.RECOMPUTE

        # Warn on the first preemption and then on every 50th one.
        if self.num_cumulative_preemption % 50 == 0:
            logger.warning(
                "Sequence group %s is preempted by %s mode because there is "
                "not enough KV cache space. This can affect the end-to-end "
                "performance. Increase gpu_memory_utilization or "
                "tensor_parallel_size to provide more KV cache memory. "
                "total_num_cumulative_preemption=%d",
                seq_group.request_id,
                preemption_mode,
                self.num_cumulative_preemption + 1,
            )
        self.num_cumulative_preemption += 1

        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
            return preemption_mode
        if preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
            return preemption_mode
        raise AssertionError("Invalid preemption mode.")
1853
1854
1855
1856
1857
1858
1859
1860
1861

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Preempt a single-sequence group by discarding its blocks.

        The group's RUNNING sequence is set back to WAITING, freed, and
        reset for recomputation; the group's cross-attention blocks are
        freed afterwards.
        """
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Recompute preemption expects exactly one running sequence.
        assert len(running_seqs) == 1
        for running_seq in running_seqs:
            running_seq.status = SequenceStatus.WAITING
            self.free_seq(running_seq)
            running_seq.reset_state_for_recompute()
        self._free_seq_group_cross_attn_blocks(seq_group)
1865
1866
1867
1868

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Preempt ``seq_group`` by swapping it out.

        Delegates to ``_swap_out``, which extends ``blocks_to_swap_out``.
        """
        self._swap_out(seq_group=seq_group,
                       blocks_to_swap_out=blocks_to_swap_out)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: List[Tuple[int, int]],
    ) -> None:
        """Swap ``seq_group`` back in and mark its sequences RUNNING.

        Args:
            seq_group: The swapped-out group to bring back.
            blocks_to_swap_in: Output list extended in place with the
                mapping returned by ``block_manager.swap_in``.
        """
        blocks_to_swap_in.extend(self.block_manager.swap_in(seq_group))
        swapped_seqs = seq_group.get_seqs(status=SequenceStatus.SWAPPED)
        for swapped_seq in swapped_seqs:
            swapped_seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Swap ``seq_group`` out and mark its RUNNING sequences SWAPPED.

        Args:
            seq_group: The running group to swap out.
            blocks_to_swap_out: Output list extended in place with the
                mapping returned by ``block_manager.swap_out``.

        Raises:
            RuntimeError: If the block manager reports that the group
                cannot be swapped out.
        """
        can_swap = self.block_manager.can_swap_out(seq_group)
        if not can_swap:
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")

        blocks_to_swap_out.extend(self.block_manager.swap_out(seq_group))
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        for running_seq in running_seqs:
            running_seq.status = SequenceStatus.SWAPPED
1898

1899
1900
1901
1902
1903
1904
1905
1906
    def _passed_delay(self, now: float) -> bool:
        """Decide whether prompt scheduling may proceed at time ``now``.

        Updates ``last_prompt_latency`` (when the previous step scheduled a
        prompt), ``prev_time`` and ``prev_prompt`` as a side effect.

        With a positive ``delay_factor`` and a non-empty waiting queue,
        prompts are held back so the waiting queue can fill up: the delay has
        passed once the oldest waiting request has waited longer than
        ``delay_factor * last_prompt_latency``, or when nothing is running.
        Otherwise the delay is always considered passed.
        """
        if self.prev_prompt:
            self.last_prompt_latency = now - self.prev_time
        self.prev_time = now
        self.prev_prompt = False

        delaying = self.scheduler_config.delay_factor > 0 and self.waiting
        if not delaying:
            return True

        # Delay scheduling prompts to let waiting queue fill up
        earliest_arrival_time = min(group.metrics.arrival_time
                                    for group in self.waiting)
        waited = now - earliest_arrival_time
        required = (self.scheduler_config.delay_factor *
                    self.last_prompt_latency)
        return waited > required or not self.running
1913

1914
1915
    def _get_num_lookahead_slots(self, is_prefill: bool,
                                 enable_chunking: bool) -> int:
        """Return the number of slots to allocate per sequence per step,
        beyond known token ids.

        Speculative decoding uses these slots to store KV activations of
        tokens which may or may not be accepted. It does not yet support
        prefill, so prefills normally get no lookahead allocation.

        The exception is chunking combined with multi-step: lookahead slots
        are then allocated for prefills too, for when the prefills turn into
        decodes in the first step.
        """
        config = self.scheduler_config
        if not is_prefill:
            return config.num_lookahead_slots

        if not (config.is_multi_step and enable_chunking):
            return 0

        # num_lookahead_slots was introduced in the context of decodes,
        # in Speculative Decoding.
        # When the num_scheduler_steps is 8, say, then the
        # num_lookahead_slots is 7. Meaning, we are doing a 1-step of
        # decode anyways and we wish to do 7 more.
        #
        # "lookaheads" for prefills, is introduced in support for
        # Chunked-Prefill in Multi-Step.
        return config.num_lookahead_slots + 1
1942

1943
1944
1945
1946
1947
1948
    def _get_num_new_uncached_and_cached_tokens(
        self,
        seq_group: SequenceGroup,
        status: SequenceStatus,
        enable_chunking: bool,
        budget: SchedulingBudget,
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> Tuple[int, int]:
        """
        Returns the number of new uncached and cached tokens to schedule for a
        given sequence group that's in a given `status`.

        Per sequence in `status`:
          * A sequence that is not in prefill contributes exactly 1 uncached
            token.
          * A prefill sequence has `get_len() - get_num_computed_tokens()` new
            tokens. Without prefix caching all of them are uncached. With
            prefix caching, the tokens reported as cached by the block manager
            beyond the already computed ones count as cached, the remainder
            as uncached.

        If every new token turned out to be cached, one token is moved from
        the cached count to the uncached count, because at least one token
        has to be computed.

        The API could chunk the number of uncached tokens to compute based on
        `budget` if `enable_chunking` is True. If a sequence group has
        multiple sequences (e.g., running beam search), it means it is in
        decoding phase, so chunking doesn't happen. Chunking only ever reduces
        the uncached count (possibly down to 0); the cached count is returned
        as computed.

        The cached tokens' blocks are already computed, and the attention
        backend will reuse the cached blocks rather than recomputing them. So
        the scheduler could schedule these cached tokens "for free".

        Args:
            seq_group: The sequence group to get the number of new tokens to
                schedule.
            status: The status of the sequences to get the number of new tokens
                to schedule.
            enable_chunking: Whether to chunk the number of tokens to compute.
            budget: The budget to chunk the number of tokens to compute.
            partial_prefill_metadata: information about the partial prefills
                that are currently running. Forwarded unchanged to
                `_chunk_new_tokens_to_schedule`.

        Returns:
            A tuple of two ints. The first int is the number of new uncached
            tokens to schedule. The second int is the number of cached tokens.
        """
        num_cached_new_tokens = 0
        num_uncached_new_tokens = 0

        seqs = seq_group.get_seqs(status=status)
        # Compute the number of new uncached and cached tokens for
        # each sequence.
        for seq in seqs:
            if not seq.is_prefill():
                # Decode sequences should always just have 1 uncached token
                # TODO(rickyx): Actually is this still correct for multi-step?
                num_uncached_new_tokens += 1
                continue

            num_computed_tokens_seq = seq.get_num_computed_tokens()
            all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
            if not self.cache_config.enable_prefix_caching:
                # If prefix caching is not enabled, all new tokens are uncached.
                num_uncached_new_tokens += all_num_new_tokens_seq
                continue

            # NOTE: the cache token might be currently in a block that's in an
            # evictor meaning that it's not yet allocated. However, we don't
            # exclude such tokens in the cache count because it will be
            # guaranteed to be allocated later if the sequence can be allocated.
            num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
                seq)

            # Sanity check.
            if num_cached_tokens_seq < num_computed_tokens_seq:
                # This should only happen with chunked prefill, and
                # the seq is still in prefill. The `num_cached_tokens_seq`
                # is the value we calculated on scheduling the first prefill.
                # For subsequent continuous prefill steps, we cached the
                # number of cache tokens for the sequence so the cached token
                # count could be less than the number of computed tokens.
                # See comments on `ComputedBlocksTracker` for more details.
                assert (
                    seq.is_prefill() and seq.status == SequenceStatus.RUNNING
                    and self.scheduler_config.chunked_prefill_enabled
                ), ("Number of cached tokens should not be less than the "
                    "number of computed tokens for a sequence that's still "
                    f"in prefill. But there are {num_cached_tokens_seq} cached "
                    f"tokens and {num_computed_tokens_seq} computed tokens "
                    f"for sequence {seq.seq_id}.")

            # Only cached tokens beyond what is already computed are "new";
            # clamp at 0 for the chunked-prefill case handled above.
            num_cached_new_tokens_seq = max(
                0, num_cached_tokens_seq - num_computed_tokens_seq)
            num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
                                           num_cached_new_tokens_seq)

            num_uncached_new_tokens += num_uncached_new_tokens_seq
            num_cached_new_tokens += num_cached_new_tokens_seq

        if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
            # For a fully cached hit sequence, we actually need to recompute the
            # last token. So we need at least 1 uncached token to schedule.
            # See ModelRunner._compute_for_prefix_cache_hit for more details.
            num_uncached_new_tokens = 1
            num_cached_new_tokens -= 1

        if enable_chunking and len(seqs) == 1:
            # Chunk if a running request cannot fit in the given budget.
            # If number of seq > 1, it means it is doing beam search
            # in a decode phase. Do not chunk.
            num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
                self.scheduler_config,
                self.cache_config,
                budget,
                self._get_prompt_limit(seq_group),
                num_uncached_new_tokens,
                self.partial_prefill_budget_lookup_list,
                partial_prefill_metadata,
            )

        return num_uncached_new_tokens, num_cached_new_tokens

    @staticmethod
    def _chunk_new_tokens_to_schedule(
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        budget: SchedulingBudget,
        prompt_limit: int,
        num_new_tokens: int,
        partial_prefill_budget_lookup_list: List[int],
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> int:
        """
        Chunks the number of new tokens to schedule based on the budget when
        chunked prefill is enabled.

        Args:
            scheduler_config: The scheduler config.
            cache_config: The cache config.
            budget: The budget to chunk the number of tokens to compute.
            prompt_limit: The maximum number of tokens allowed in a prompt.
            num_new_tokens: The number of new tokens to schedule.
            partial_prefill_budget_lookup_list: Per-slot token budgets,
                indexed by `partial_prefill_metadata.schedulable_prefills`.
                Only consulted when `partial_prefill_metadata` is not None.
            partial_prefill_metadata: information about the partial prefills
                that are currently running. When None, the slot budget is
                the whole remaining token budget.

        Returns:
            The number of new tokens to schedule after chunking.

            With multi-step scheduling prompts are never split: the result is
            `num_new_tokens` unchanged if it exceeds `prompt_limit`, 0 if it
            exceeds the remaining token budget, and `num_new_tokens`
            otherwise.

            Otherwise the result is the minimum of `num_new_tokens`, the
            remaining token budget (rounded down to a multiple of the block
            size when prefix caching is enabled) and the prefill slot budget.
        """
        remaining_token_budget = budget.remaining_token_budget()
        if scheduler_config.is_multi_step:
            # The current multi-step + chunked prefill capability does
            # not actually support chunking prompts.
            #
            # Therefore, `num_new_tokens` is computed in the same fashion
            # for both multi-step+chunked-prefill &
            # multi-step+chunked-prefill+APC
            #
            # Prompts with more tokens than the current remaining budget
            # are postponed to future scheduler steps
            if num_new_tokens > prompt_limit:
                # If the seq_group is in prompt-stage, pass the
                # num_new_tokens as-is so the caller can ignore
                # the sequence.
                return num_new_tokens

            if num_new_tokens > remaining_token_budget:
                return 0
            return num_new_tokens

        # Get the number of tokens to allocate to this prefill slot
        prefill_slot_budget = (
            remaining_token_budget if partial_prefill_metadata is None else
            partial_prefill_budget_lookup_list[
                partial_prefill_metadata.schedulable_prefills])

        if cache_config.enable_prefix_caching:
            # When prefix caching is enabled and we're partially prefilling
            # a sequence, we always allocate a number of new tokens that is
            # divisible by the block size to avoid partial block matching.
            block_size = cache_config.block_size
            # Don't exceed either the total budget or slot budget.
            # Take min of those and get the next lowest multiple of the
            # block size:
            remaining_token_budget = (
                min(remaining_token_budget, prefill_slot_budget) //
                block_size) * block_size
            # NB: In the case where num_new_tokens < budget, we are
            # finishing prefill for this sequence, so we do not need to
            # allocate a full block.

        num_new_tokens = min(num_new_tokens, remaining_token_budget,
                             prefill_slot_budget)

        return num_new_tokens