# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import enum
5
6
import os
import random
7
import time
8
from collections import deque
9
from dataclasses import dataclass, field
10
11
12
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
15
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
19
20
21
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupMetadataDelta, SequenceStage,
                           SequenceStatus)
22
from vllm.utils import Device, PyObjectCache
Woosuk Kwon's avatar
Woosuk Kwon committed
23

Woosuk Kwon's avatar
Woosuk Kwon committed
24
# Module-level logger for this file, obtained from vllm.logger.init_logger.
logger = init_logger(__name__)
25

26
27
28
29
30
31
32
def _env_flag(name: str) -> bool:
    """Interpret the environment variable ``name`` as a boolean flag.

    Unset, empty, and the conventional "off" spellings ("0", "false", "no",
    "off"; case-insensitive, surrounding whitespace ignored) are False; any
    other value is True. A plain ``bool(os.getenv(name))`` would treat
    "0" and "false" as True because they are non-empty strings.
    """
    value = os.getenv(name)
    if value is None:
        return False
    return value.strip().lower() not in ("", "0", "false", "no", "off")


# Test-only. If configured, decode is preempted with probability
# ARTIFICIAL_PREEMPTION_PROB (a fraction: 0.5 means 50%).
ENABLE_ARTIFICIAL_PREEMPT = _env_flag("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT")
ARTIFICIAL_PREEMPTION_PROB = 0.5
# Initial value of Scheduler.artificial_preempt_cnt when artificial
# preemption is enabled (presumably a cap on injected preemptions).
ARTIFICIAL_PREEMPTION_MAX_CNT = 500

Woosuk Kwon's avatar
Woosuk Kwon committed
33

34
35
36
37
38
39
40
41
42
class PreemptionMode(enum.Enum):
    """How a sequence group is evicted when GPU memory runs out.

    SWAP: copy the group's blocks to CPU memory and copy them back to the
        GPU once the group is resumed.
    RECOMPUTE: drop the group's blocks entirely; on resume the sequences are
        treated as fresh prompts and their KV state is recomputed.
    """

    # Explicit values equal to what enum.auto() assigns (1, 2).
    SWAP = 1
    RECOMPUTE = 2


48
49
@dataclass
class SchedulingBudget:
    """Token and sequence slots still available in the batch being built.

    The budget is request-id aware. Each ``add_*`` call takes effect at most
    once per request id. The matching ``subtract_*`` call only takes effect
    if that request id is currently recorded. This exists because the normal
    scheduling path updates RUNNING ``num_seqs`` ahead of time, so the same
    request may be added more than once.

    TODO(sang): chunked-prefill-only scheduling never double-adds, so this
    request-id bookkeeping can be removed from the API once chunked prefill
    is enabled by default.
    """

    token_budget: int
    max_num_seqs: int
    # Request ids whose tokens / seqs are currently counted in this budget.
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    # Number of cached tokens in the batch.
    _num_cached_tokens: int = 0
    # Number of actual non-cached tokens in the batch.
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
        """Return True if both the new tokens and new seqs fit the budget."""
        # Zero new tokens is legal: the entire sequence may be cached already.
        assert num_new_tokens >= 0
        assert num_new_seqs != 0
        tokens_fit = (self.num_batched_tokens + num_new_tokens
                      <= self.token_budget)
        seqs_fit = self.num_curr_seqs + num_new_seqs <= self.max_num_seqs
        return tokens_fit and seqs_fit

    def remaining_token_budget(self):
        """Tokens that can still be added before hitting ``token_budget``."""
        used = self.num_batched_tokens
        return self.token_budget - used

    def add_num_batched_tokens(self,
                               req_id: str,
                               num_batched_tokens: int,
                               num_cached_tokens: int = 0):
        """Count ``req_id``'s tokens once; repeated calls are ignored."""
        counted = self._request_ids_num_batched_tokens
        if req_id in counted:
            return

        assert num_cached_tokens >= 0
        assert num_batched_tokens >= 0

        counted.add(req_id)
        self._num_batched_tokens += num_batched_tokens
        self._num_cached_tokens += num_cached_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int):
        """Uncount ``req_id``'s tokens if (and only if) they were counted.

        Only the non-cached token count is adjusted.
        """
        counted = self._request_ids_num_batched_tokens
        if req_id not in counted:
            return
        counted.remove(req_id)
        self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Count ``req_id``'s sequences once; repeated calls are ignored."""
        counted = self._request_ids_num_curr_seqs
        if req_id in counted:
            return
        counted.add(req_id)
        self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Uncount ``req_id``'s sequences if (and only if) they were counted."""
        counted = self._request_ids_num_curr_seqs
        if req_id not in counted:
            return
        counted.remove(req_id)
        self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self):
        """Non-cached tokens currently in the batch."""
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self):
        """Sequences currently in the batch."""
        return self._num_curr_seqs

    @property
    def num_cached_tokens(self):
        """Cached tokens currently in the batch."""
        return self._num_cached_tokens

124

125
126
127
128
129
130
131
132
133
134
@dataclass
class ScheduledSequenceGroup:
    """A sequence group chosen for the next iteration plus its chunk size."""

    # The sequence group being scheduled.
    seq_group: SequenceGroup
    # Tokens to process for this group in the next iteration: 1 for a
    # decode, the prompt length for a prefill, or fewer than that when the
    # prefill is chunked.
    token_chunk_size: int


135
@dataclass
class SchedulerOutputs:
    """The scheduling decision produced by one scheduler step.

    Bundles the sequence groups chosen to run together with the block
    operations (swap in, swap out, copy) that the decision requires.
    """

    # Sequence groups selected to run.
    scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
    # How many of the scheduled groups are prefills.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # (CPU block, GPU block) pairs to swap in.
    blocks_to_swap_in: List[Tuple[int, int]]
    # (GPU block, CPU block) pairs to swap out.
    blocks_to_swap_out: List[Tuple[int, int]]
    # (source block, destination block) pairs to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # Slots to reserve for lookahead decoding.
    num_lookahead_slots: int
    # Size of the running queue.
    running_queue_size: int
    preempted: int

    def __post_init__(self):
        # A single step must never both swap in and swap out.
        assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)

        self.num_loras: int = len(self.lora_requests)
        if self.num_loras:
            self._sort_by_lora_ids()

    def is_empty(self) -> bool:
        """True when there is nothing to run and no block work to do.

        Ignored sequence groups deliberately do not count.
        """
        return not any((
            self.scheduled_seq_groups,
            self.blocks_to_swap_in,
            self.blocks_to_swap_out,
            self.blocks_to_copy,
        ))

    def _sort_by_lora_ids(self):
        """Order scheduled groups by (LoRA id, request id).

        In a mixed batch (some but not all groups are prefills), prefills
        are placed before decodes, as chunked prefill requires.
        """
        total = len(self.scheduled_seq_groups)
        assert 0 <= self.num_prefill_groups <= total
        is_mixed_batch = 0 < self.num_prefill_groups < total

        def sort_key(scheduled: ScheduledSequenceGroup):
            group = scheduled.seq_group
            base = (group.lora_int_id, group.request_id)
            if not is_mixed_batch:
                return base
            # False sorts before True, so prefills come first.
            return (not group.is_prefill(), ) + base

        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
                                           key=sort_key)

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """Distinct LoRA requests used by the scheduled groups."""
        found: Set[LoRARequest] = set()
        for scheduled in self.scheduled_seq_groups:
            lora = scheduled.seq_group.lora_request
            if lora is not None:
                found.add(lora)
        return found
193

194

195
@dataclass
class SchedulerRunningOutputs:
    """Outcome of scheduling the RUNNING queue.

    Contains decodes and chunked prefills that keep running. When memory is
    short, it also contains the groups that were preempted (for recompute)
    or swapped out, plus the block operations those decisions need.
    """

    # Running groups in the decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Running groups still in the prefill phase, i.e. a chunked prefill.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The preempted sequence groups.
    preempted: List[SequenceGroup]
    # The swapped-out sequence groups.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # Slots to reserve for lookahead decoding.
    num_lookahead_slots: int

    # Plain SequenceGroup lists kept as a fast-access optimization.
    decode_seq_groups_list: List[SequenceGroup]
    prefill_seq_groups_list: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an instance whose lists are all empty (0 lookahead slots)."""
        return SchedulerRunningOutputs(
            decode_seq_groups=[],
            decode_seq_groups_list=[],
            prefill_seq_groups=[],
            prefill_seq_groups_list=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
        )


@dataclass
class SchedulerSwappedInOutputs:
    """Outcome of scheduling the SWAPPED queue.

    Holds the groups selected to be swapped back in. These may be decodes
    or chunked prefills.
    """

    # Groups to swap in that are in the decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Groups to swap in that are in the prefill phase (a chunked prefill).
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # Slots to reserve for lookahead decoding.
    num_lookahead_slots: int
    # Sequence groups that cannot be scheduled.
    infeasible_seq_groups: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance whose lists are all empty (0 lookahead slots)."""
        return SchedulerSwappedInOutputs(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            infeasible_seq_groups=[],
            blocks_to_swap_in=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
        )


@dataclass
class SchedulerPrefillOutputs:
    """Outcome of scheduling the WAITING queue.

    The selected groups are either fresh prompts or preempted requests that
    must be recomputed from scratch.
    """

    # Groups selected for prefill.
    seq_groups: List[ScheduledSequenceGroup]
    # Groups that will be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # Slots to reserve for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an instance with empty lists and 0 lookahead slots."""
        return SchedulerPrefillOutputs(
            num_lookahead_slots=0,
            seq_groups=[],
            ignored_seq_groups=[],
        )


295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
def seq_group_metadata_builder():
    """Build a blank, non-prompt SequenceGroupMetadata for a PyObjectCache."""
    return SequenceGroupMetadata(
        request_id="",
        is_prompt=False,
        block_tables={},
        seq_data={},
        sampling_params=None,
    )


def scheduler_running_outputs_builder():
    """Build an empty SchedulerRunningOutputs for a PyObjectCache.

    Delegates to ``SchedulerRunningOutputs.create_empty`` instead of
    repeating its field list, so the two constructors cannot drift apart
    when fields are added.
    """
    return SchedulerRunningOutputs.create_empty()


def scheduled_seq_group_builder():
    """Build a placeholder ScheduledSequenceGroup for a PyObjectCache.

    ``SequenceGroup.__new__`` allocates a SequenceGroup without running its
    ``__init__``. The placeholder's fields are presumably overwritten by the
    code that takes objects from the cache; confirm before relying on any
    field of a freshly built instance.
    """
    return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup),
                                  token_chunk_size=0)
319
320


321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
@dataclass
class PartialPrefillMetadata:
    """Holds information about the partial prefills that are currently running
    during a single iteration of the Scheduler.

    When chunked prefill is enabled, we allow a certain number of seqs to be
    partially prefilled during each iteration. Having multiple partial prefills
    in flight allows us to minimize TTFT and avoid decode starvation in cases
    where a single sequence group with a very large prompt blocks the queue for
    too many iterations.

    The number of long prefill requests is limited so that smaller
    requests may jump the queue in front of them and get to the decode
    phase faster.
    """

    # Number of prefills that may be scheduled this iteration: running
    # prefills plus eligible waiting ones, capped at max_num_partial_prefills
    # (see from_queues).
    schedulable_prefills: int

    # The number of long prefill requests currently running
    long_prefills: int

    scheduler_config: SchedulerConfig

    @staticmethod
    def _is_long_prefill(seq_group: SequenceGroup,
                         scheduler_config: SchedulerConfig) -> bool:
        """True if the group's first sequence needs more new tokens than
        ``long_prefill_token_threshold``."""
        return (seq_group.first_seq.get_num_new_tokens()
                > scheduler_config.long_prefill_token_threshold)

    def can_schedule(self, seq_group: SequenceGroup) -> bool:
        """When concurrent partial prefills are enabled,
        we limit the number of long requests and only accept
        shorter requests from the queue while running them
        concurrently"""
        config = self.scheduler_config
        return not (self._is_long_prefill(seq_group, config)
                    and self.long_prefills >= config.max_long_partial_prefills
                    and config.max_num_partial_prefills > 1)

    def maybe_increment_partial_prefills(self,
                                         seq_group: SequenceGroup) -> None:
        """Record a newly scheduled prefill, counting it if it is long."""
        if self._is_long_prefill(seq_group, self.scheduler_config):
            self.long_prefills += 1

    @classmethod
    def from_queues(
        cls,
        running: Deque[SequenceGroup],
        waiting: Deque[SequenceGroup],
        scheduler_config: SchedulerConfig,
    ) -> "PartialPrefillMetadata":
        """Create a PartialPrefillMetadata object from the current state of
        the scheduler's queues.

        This accounts for the currently running prefill requests, and peeks
        into the waiting queue to see if there are more prefills to
        potentially be scheduled during this iteration.
        """
        prefills = 0
        long_prefills = 0

        waiting_long_prefills = 0

        for sg in running:
            if sg.first_seq.data.stage == SequenceStage.PREFILL:
                prefills += 1
                if cls._is_long_prefill(sg, scheduler_config):
                    long_prefills += 1

        for sg in waiting:
            # Don't bother looping through the rest of the queue if we know
            # there are already at least max_partial_prefills requests to fill
            if prefills >= scheduler_config.max_num_partial_prefills:
                break

            # Don't count long requests from the waiting queue if we aren't
            # going to schedule them anyway
            if cls._is_long_prefill(sg, scheduler_config):
                if (long_prefills + waiting_long_prefills
                        >= scheduler_config.max_long_partial_prefills):
                    continue
                waiting_long_prefills += 1
            prefills += 1

        # NB: long_prefills and waiting_long_prefills are tracked separately.
        # We don't account for the waiting requests here because we need to use
        # this metadata to track how many have actually been scheduled.
        # Use cls so subclasses get instances of their own type.
        return cls(
            schedulable_prefills=min(
                prefills, scheduler_config.max_num_partial_prefills),
            long_prefills=long_prefills,
            scheduler_config=scheduler_config,
        )


Woosuk Kwon's avatar
Woosuk Kwon committed
415
416
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
417
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        pipeline_parallel_size: int = 1,
        output_proc_callback: Optional[Callable] = None,
    ) -> None:
        """Set up the queues, the block space manager and per-step state.

        Args:
            scheduler_config: Scheduler settings (runner type, preemption
                mode, partial-prefill limits, batched-token budget, ...).
            cache_config: KV-cache settings. Its GPU/CPU block counts, when
                truthy, are floor-divided by ``pipeline_parallel_size``.
            lora_config: LoRA settings, or None.
            pipeline_parallel_size: Number of pipeline stages sharing the
                cache blocks.
            output_proc_callback: Async output-processing callback. When set,
                two generations of cached Python objects are kept, because
                output processing lags one step behind scheduling.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        # Pooling runners and attention-free models get the "placeholder"
        # block manager; everything else uses "selfattn".
        version = "selfattn"
        if (self.scheduler_config.runner_type == "pooling"
                or self.cache_config.is_attention_free):
            version = "placeholder"

        BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
            version)

        # Split the cache blocks across pipeline stages. Falsy counts
        # (None or 0) are passed through unchanged.
        num_gpu_blocks = cache_config.num_gpu_blocks
        if num_gpu_blocks:
            num_gpu_blocks //= pipeline_parallel_size

        num_cpu_blocks = cache_config.num_cpu_blocks
        if num_cpu_blocks:
            num_cpu_blocks //= pipeline_parallel_size

        # Create the block space manager.
        self.block_manager = BlockSpaceManagerImpl(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching,
        )

        # Sequence groups in the WAITING state.
        # Contain new prefill or preempted requests.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        # Contain decode requests.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        # Contain decode requests that are swapped out.
        self.swapped: Deque[SequenceGroup] = deque()
        # Sequence groups finished requests ids since last step iteration.
        # It lets the model know that any state associated with these requests
        # can and must be released after the current step.
        # This is used to evict the finished requests from the Mamba cache.
        self._finished_requests_ids: List[str] = list()
        # Time at previous scheduling step
        self.prev_time = 0.0
        # Did we schedule a prompt at previous step?
        self.prev_prompt = False
        # Latency of the last prompt step
        self.last_prompt_latency = 0.0
        # preemption mode, RECOMPUTE or SWAP
        self.user_specified_preemption_mode = scheduler_config.preemption_mode

        # The following field is test-only. It is used to inject artificial
        # preemption.
        self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
        self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
                                       if self.enable_artificial_preemption
                                       else 0)
        self.num_cumulative_preemption: int = 0

        # Used to cache python objects: one PyObjectCache per cache
        # generation, selected by self.cache_id.
        self._seq_group_metadata_cache: List[PyObjectCache] = []
        self._scheduler_running_outputs_cache: List[PyObjectCache] = []
        self._scheduled_seq_group_cache: List[PyObjectCache] = []

        # For async output processing, we need to swap cache buffers between
        # iterations. I.e. since the output processing is lagged one step,
        # we cannot reuse the cached objects immediately when the schedule()
        # is called again, but only when schedule() is called the second time.
        self.output_proc_callback = output_proc_callback
        self.use_async_output_proc = self.output_proc_callback is not None
        self.num_cache_iters = 2 if self.use_async_output_proc else 1

        self.cache_id = 0
        for _ in range(self.num_cache_iters):
            self._seq_group_metadata_cache.append(
                PyObjectCache(seq_group_metadata_builder))
            self._scheduler_running_outputs_cache.append(
                PyObjectCache(scheduler_running_outputs_builder))
            self._scheduled_seq_group_cache.append(
                PyObjectCache(scheduled_seq_group_builder))

        # For async postprocessor, the extra decode run cannot be done
        # when the request reaches max_model_len. In this case, the request
        # will be stopped during schedule() call and added to this stop list
        # for processing and deallocation by the free_finished_seq_groups()
        self._async_stopped: List[SequenceGroup] = []

        # List with the chunk sizes to hand out to each sequence depending
        # on how many partial prefills are running. This is slightly faster than
        # running an integer division every time a prefill is scheduled.
        # This splits the budget evenly among all prefills. Index 0 (no
        # partial prefills running) gets the whole budget.
        max_tokens = scheduler_config.max_num_batched_tokens
        self.partial_prefill_budget_lookup_list = [max_tokens] + [
            max_tokens // num_prefills for num_prefills in range(
                1, self.scheduler_config.max_num_partial_prefills + 1)
        ]

    @property
    def next_cache_id(self):
        """Cache slot index following ``cache_id``, wrapping at
        ``num_cache_iters``."""
        return (1 + self.cache_id) % self.num_cache_iters

    @property
    def lora_enabled(self) -> bool:
        """True when a truthy LoRA config was supplied."""
        return True if self.lora_config else False

    @property
    def num_decoding_tokens_per_seq(self) -> int:
        """The number of new tokens per decoding sequence (always 1)."""
        new_tokens = 1
        return new_tokens

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        """Enqueue ``seq_group`` at the tail of the waiting queue."""
        waiting_queue = self.waiting
        waiting_queue.append(seq_group)

    def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
        """Append ``seq_group`` to the running queue (testing only)."""
        running_queue = self.running
        running_queue.append(seq_group)

    def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
        """Append ``seq_group`` to the swapped queue (testing only)."""
        swapped_queue = self.swapped
        swapped_queue.append(seq_group)

    def abort_seq_group(
        self,
        request_id: Union[str, Iterable[str]],
        seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
    ) -> None:
        """Abort every sequence group that belongs to the given request id(s).

        The waiting, running and swapped queues are scanned in that order.
        Each matching group goes through these steps:

        - It is removed from its queue.
        - Its request id is recorded in the finished-requests list.
        - Every sequence in it that is not already finished is marked
          ``FINISHED_ABORTED`` and freed.
        - Its cross-attention blocks are released.

        Ids that match nothing are ignored.

        Args:
            request_id: The ID(s) of the sequence group to abort.
            seq_id_to_seq_group: helper for groups with n>1. It maps a
                per-sample request id (e.g. ``foo_parallel_sample_0``) to an
                object whose ``group_id`` is the user-facing id (``foo``).
                Entries of aborted groups are deleted from this mapping.
        """
        ids = (request_id, ) if isinstance(request_id, str) else request_id
        targets = set(ids)
        lookup = seq_id_to_seq_group or {}

        def resolve_request_id(group):
            # With n>1, a per-sample id must be mapped back to its parent id
            # before it can be compared against the requested ids.
            if group.request_id in lookup:
                return lookup[group.request_id].group_id
            return group.request_id

        for queue in (self.waiting, self.running, self.swapped):
            # Collect matches first: mutating a deque while iterating over
            # it raises. ``targets`` is never pruned, because several groups
            # may share one parent request id.
            victims: List[SequenceGroup] = [
                group for group in queue
                if resolve_request_id(group) in targets
            ]
            for victim in victims:
                queue.remove(victim)
                # Remove the aborted request from the Mamba cache.
                self._finished_requests_ids.append(victim.request_id)
                for seq in victim.get_seqs():
                    if not seq.is_finished():
                        seq.status = SequenceStatus.FINISHED_ABORTED
                        self.free_seq(seq)
                if victim.request_id in lookup:
                    del lookup[victim.request_id]
                self._free_seq_group_cross_attn_blocks(victim)

    def _free_seq_group_cross_attn_blocks(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Release the cross-attention block table of ``seq_group``.

        Does nothing unless ``seq_group.is_encoder_decoder()`` is true, so
        decoder-only models are unaffected.
        """
        if not seq_group.is_encoder_decoder():
            return
        self.block_manager.free_cross(seq_group)

    def has_unfinished_seqs(self) -> bool:
        """True if any of the waiting, running or swapped queues is
        non-empty."""
        return any(
            len(queue) != 0
            for queue in (self.waiting, self.running, self.swapped))

    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        """Prefix-cache hit rate for ``device``, as reported by the block
        manager."""
        manager = self.block_manager
        return manager.get_prefix_cache_hit_rate(device)

    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Ask the block manager to reset its prefix cache and return its
        result."""
        manager = self.block_manager
        return manager.reset_prefix_cache(device)

    def get_num_unfinished_seq_groups(self) -> int:
        """Total number of groups across the waiting, running and swapped
        queues."""
        return sum(
            len(queue)
            for queue in (self.waiting, self.running, self.swapped))

    def get_and_reset_finished_requests_ids(self) -> List[str]:
        """Return the finished request ids collected so far and start a new,
        empty list."""
        finished, self._finished_requests_ids = self._finished_requests_ids, []
        return finished

    def _schedule_running(
640
641
642
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
643
        enable_chunking: bool = False,
644
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
645
    ) -> SchedulerRunningOutputs:
646
        """Schedule sequence groups that are running.
647

648
        Running queue should include decode and chunked prefill requests.
Woosuk Kwon's avatar
Woosuk Kwon committed
649

650
651
652
653
654
        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any decodes are preempted.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any decodes are preempted.
655
656
657
658
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled  if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.
659
            partial_prefill_metadata: information about the partial prefills
660
                that are currently running
661

662
        Returns:
663
            SchedulerRunningOutputs.
664
        """
665
666
        ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
            self.cache_id].get_object()
667
668
669
670
671
672
673
674
        ret.blocks_to_swap_out.clear()
        ret.blocks_to_copy.clear()
        ret.decode_seq_groups.clear()
        ret.prefill_seq_groups.clear()
        ret.preempted.clear()
        ret.swapped_out.clear()

        ret.num_lookahead_slots = self._get_num_lookahead_slots(
675
            is_prefill=False, enable_chunking=enable_chunking)
676
677
678
679

        ret.decode_seq_groups_list.clear()
        ret.prefill_seq_groups_list.clear()

680
        # Blocks that need to be swapped or copied before model execution.
681
682
        blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
        blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy
Woosuk Kwon's avatar
Woosuk Kwon committed
683

684
685
686
687
688
        decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
        prefill_seq_groups: List[
            ScheduledSequenceGroup] = ret.prefill_seq_groups
        preempted: List[SequenceGroup] = ret.preempted
        swapped_out: List[SequenceGroup] = ret.swapped_out
Woosuk Kwon's avatar
Woosuk Kwon committed
689

690
691
        running_queue = self.running
        assert len(self._async_stopped) == 0
692
693
        while running_queue:
            seq_group = running_queue[0]
694
695
696
697
698
699
700
            # We discard the cached tokens info here because we don't need it
            # for running sequence:
            #   1. If a sequence is running with chunked prefill, the cached
            #      tokens info was already used for the first prefill.
            #   2. If a sequence is running with non-chunked prefill, then
            #      there it's a decoding sequence, and the cached tokens info is
            #      irrelevant.
701
            num_uncached_new_tokens, _ = \
702
                self._get_num_new_uncached_and_cached_tokens(
703
704
705
706
707
708
                seq_group,
                SequenceStatus.RUNNING,
                enable_chunking,
                budget,
                partial_prefill_metadata,
            )
709
710

            num_running_tokens = num_uncached_new_tokens
711
            if num_running_tokens == 0:
712
                # No budget => Stop
713
                break
714
715

            running_queue.popleft()
716
717
718
719
720

            # With async postprocessor, an extra decode run is done
            # to process the final tokens. The check below avoids this extra
            # decode run when the model max len is reached, in order to avoid
            # a memory overflow.
721
722
            if (self.use_async_output_proc and seq_group.seqs[0].get_len()
                    > self.scheduler_config.max_model_len):
723
724
725
                self._async_stopped.append(seq_group)
                continue

726
727
            # NOTE(woosuk): Preemption happens only when there is no available
            # slot to keep all the sequence groups in the RUNNING state.
728
            while not self._can_append_slots(seq_group, enable_chunking):
729
730
                budget.subtract_num_batched_tokens(seq_group.request_id,
                                                   num_running_tokens)
731
                num_running_seqs = seq_group.get_max_num_running_seqs()
732
733
                budget.subtract_num_seqs(seq_group.request_id,
                                         num_running_seqs)
734
735
736

                if (curr_loras is not None and seq_group.lora_int_id > 0
                        and seq_group.lora_int_id in curr_loras):
737
                    curr_loras.remove(seq_group.lora_int_id)
738

739
740
                # Determine victim sequence
                cont_loop = True
741
                if running_queue:
742
                    # Preempt the lowest-priority sequence group.
743
                    victim_seq_group = running_queue.pop()
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
                else:
                    # No other sequence group can be preempted.
                    # Preempt the current sequence group.
                    # Note: This is also where we stop this loop
                    # (since there is nothing else to preempt)
                    victim_seq_group = seq_group
                    cont_loop = False

                # With async postprocessor, before preempting a sequence
                # we need to ensure it has no pending async postprocessor
                do_preempt = True
                if self.use_async_output_proc:
                    assert self.output_proc_callback is not None
                    self.output_proc_callback(
                        request_id=victim_seq_group.request_id)

                    # It may be that the async pending "victim_seq_group"
                    # becomes finished, in which case we simply free it.
                    if victim_seq_group.is_finished():
                        self._free_finished_seq_group(victim_seq_group)
                        do_preempt = False

                # Do preemption
                if do_preempt:
768
769
770
771
772
773
                    preempted_mode = self._preempt(victim_seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(victim_seq_group)
                    else:
                        swapped_out.append(victim_seq_group)
774
775

                if not cont_loop:
Woosuk Kwon's avatar
Woosuk Kwon committed
776
777
                    break
            else:
778
                self._append_slots(seq_group, blocks_to_copy, enable_chunking)
779
                is_prefill = seq_group.is_prefill()
780

781
782
783
                scheduled_seq_group: ScheduledSequenceGroup = (
                    self._scheduled_seq_group_cache[
                        self.cache_id].get_object())
784
                scheduled_seq_group.seq_group = seq_group
785
                if is_prefill:
786
787
788
                    scheduled_seq_group.token_chunk_size = num_running_tokens
                    prefill_seq_groups.append(scheduled_seq_group)
                    ret.prefill_seq_groups_list.append(seq_group)
789
                else:
790
791
792
793
                    scheduled_seq_group.token_chunk_size = 1
                    decode_seq_groups.append(scheduled_seq_group)
                    ret.decode_seq_groups_list.append(seq_group)

794
795
                budget.add_num_batched_tokens(seq_group.request_id,
                                              num_running_tokens)
796
797
798
799
800
801
802
                # OPTIMIZATION:  Note that get_max_num_running_seqs is
                # expensive. For the default scheduling chase where
                # enable_chunking is False, num_seqs are updated before running
                # this method, so we don't have to update it again here.
                if enable_chunking:
                    num_running_seqs = seq_group.get_max_num_running_seqs()
                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)
803
804
805
                if curr_loras is not None and seq_group.lora_int_id > 0:
                    curr_loras.add(seq_group.lora_int_id)

806
807
        self._scheduler_running_outputs_cache[self.next_cache_id].reset()
        self._scheduled_seq_group_cache[self.next_cache_id].reset()
808
809

        return ret
810

811
812
813
814
    def _schedule_swapped(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerSwappedInOutputs:
        """Swap previously swapped-out sequence groups back in.

        Groups are examined from the front of ``self.swapped``:

          * If the block manager answers ``AllocStatus.LATER`` the scan stops.
          * If it answers ``AllocStatus.NEVER`` every sequence in the group is
            marked FINISHED_IGNORED and the group is reported as infeasible.
          * If LoRA is enabled and the group needs a new adapter slot while
            ``curr_loras`` already holds ``lora_config.max_loras`` ids, the
            group is deferred and put back at the front of the queue (in its
            original order) once the scan ends.
          * If the group has no new uncached tokens or does not fit in
            ``budget``, the scan stops.
          * Otherwise the group is swapped in and scheduled.

        Args:
            budget: The scheduling budget. Updated in place for every group
                that is swapped in.
            curr_loras: Currently batched LoRA int ids. Updated in place for
                every group that is swapped in; must not be None when LoRA is
                enabled.
            enable_chunking: If True, a seq group can be chunked so that only
                part of its tokens is scheduled when
                `budget.num_batched_tokens` does not have enough capacity for
                all of them.

        Returns:
            SchedulerSwappedInOutputs.
        """
        # Block mappings that must be applied before model execution.
        swap_in_blocks: List[Tuple[int, int]] = []
        copy_blocks: List[Tuple[int, int]] = []

        scheduled_decodes: List[ScheduledSequenceGroup] = []
        scheduled_prefills: List[ScheduledSequenceGroup] = []
        infeasible: List[SequenceGroup] = []

        queue = self.swapped
        # Groups held back by the LoRA slot limit. appendleft + extendleft
        # below restores their original relative order.
        deferred: Deque[SequenceGroup] = deque()

        while queue:
            candidate = queue[0]
            is_prefill = candidate.is_prefill()

            status = self.block_manager.can_swap_in(
                candidate,
                self._get_num_lookahead_slots(is_prefill, enable_chunking))
            if status == AllocStatus.LATER:
                # Not enough free blocks right now; stop scanning.
                break
            if status == AllocStatus.NEVER:
                logger.warning(
                    "Failing the request %s because there's not enough kv "
                    "cache blocks to run the entire sequence.",
                    candidate.request_id,
                )
                for seq in candidate.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible.append(candidate)
                queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = candidate.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                needs_new_lora_slot = (lora_int_id > 0
                                       and lora_int_id not in curr_loras)
                if (needs_new_lora_slot
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # No room for another adapter in this batch; retry later.
                    deferred.appendleft(candidate)
                    queue.popleft()
                    continue

            # Keep the number of RUNNING sequences within the budget.
            num_new_seqs = candidate.get_max_num_running_seqs()
            uncached, cached = (
                self._get_num_new_uncached_and_cached_tokens(
                    candidate, SequenceStatus.SWAPPED, enable_chunking,
                    budget))

            fits = uncached != 0 and budget.can_schedule(
                num_new_tokens=uncached,
                num_new_seqs=num_new_seqs,
            )
            if not fits:
                self.remove_seq_from_computed_blocks_tracker(
                    candidate, SequenceStatus.SWAPPED)
                break

            if lora_int_id > 0 and curr_loras is not None:
                curr_loras.add(lora_int_id)
            queue.popleft()
            self._swap_in(candidate, swap_in_blocks)
            self._append_slots(candidate, copy_blocks, enable_chunking)

            if is_prefill:
                scheduled_prefills.append(
                    ScheduledSequenceGroup(
                        candidate,
                        token_chunk_size=uncached + cached,
                    ))
            else:
                scheduled_decodes.append(
                    ScheduledSequenceGroup(candidate, token_chunk_size=1))

            budget.add_num_batched_tokens(
                candidate.request_id,
                num_batched_tokens=uncached,
                num_cached_tokens=cached,
            )
            budget.add_num_seqs(candidate.request_id, num_new_seqs)

        queue.extendleft(deferred)

        return SchedulerSwappedInOutputs(
            decode_seq_groups=scheduled_decodes,
            prefill_seq_groups=scheduled_prefills,
            blocks_to_swap_in=swap_in_blocks,
            blocks_to_copy=copy_blocks,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False, enable_chunking=enable_chunking),
            infeasible_seq_groups=infeasible,
        )

    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
932
        if self.scheduler_config.chunked_prefill_enabled:
933
934
            prompt_limit = self.scheduler_config.max_model_len
        else:
935
936
937
938
            prompt_limit = min(
                self.scheduler_config.max_model_len,
                self.scheduler_config.max_num_batched_tokens,
            )
939
940

        # Model is fine tuned with long context. Return the fine tuned max_len.
941
        if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
942
943
944
945
946
            assert prompt_limit <= seq_group.lora_request.long_lora_max_len
            return seq_group.lora_request.long_lora_max_len
        else:
            return prompt_limit

947
948
    def _get_priority(self,
                      seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
949
        """Get the priority of the sequence group.
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
        Highest preference to user-defined priority, followed by arrival time.
        Args:
            seq_group: The sequence group input.
        Returns:
            The priority of the sequence group.
        """
        return seq_group.priority, seq_group.arrival_time

    def _schedule_priority_preemption(
        self,
        budget: SchedulingBudget,
    ) -> int:
        """Reorder the queues by priority and force-preempt running groups.

        Only the head of ``self.waiting`` is considered. As long as the
        lowest-priority running group ranks strictly after that head (per
        ``_get_priority``), and the head cannot yet be allocated and fit in
        ``budget``, the lowest-priority running group is evicted. Its token
        and sequence share of ``budget`` is released, it is preempted, and it
        is pushed onto the waiting queue. The head itself is put back at the
        front of the waiting queue. Finally ``self.waiting`` and
        ``self.running`` are stored back sorted by priority.

        Used with the "priority" scheduling policy.

        Args:
            budget: The scheduling budget. Updated in place as running groups
                are preempted.

        Returns:
            The number of priority-based preemptions performed.
        """
        waiting_queue = self.waiting
        running_queue = deque(sorted(self.running, key=self._get_priority))

        # NOTE(review): the swap-out block mappings produced by `_preempt` are
        # collected here but never returned or used. If a victim is preempted
        # in swap mode, those mappings are lost. Confirm this is intended.
        blocks_to_swap_out: List[Tuple[int, int]] = []
        num_preempted = 0

        if waiting_queue:
            candidate = waiting_queue.popleft()
            candidate_seqs = candidate.get_max_num_running_seqs()
            candidate_tokens, _ = (
                self._get_num_new_uncached_and_cached_tokens(
                    candidate, SequenceStatus.WAITING, False, budget))

            # Preempt only while a priority inversion exists.
            while running_queue and self._get_priority(
                    running_queue[-1]) > self._get_priority(candidate):
                # Stop evicting once the candidate can be placed as is.
                alloc_status = self.block_manager.can_allocate(candidate)
                candidate_fits = (candidate_tokens > 0
                                  and alloc_status == AllocStatus.OK
                                  and budget.can_schedule(
                                      num_new_tokens=candidate_tokens,
                                      num_new_seqs=candidate_seqs,
                                  ))
                if candidate_fits:
                    break

                # Evict the lowest-priority running group and give its
                # tokens and sequences back to the budget.
                victim = running_queue.pop()
                victim_tokens, _ = (
                    self._get_num_new_uncached_and_cached_tokens(
                        victim, SequenceStatus.RUNNING, False, budget))
                budget.subtract_num_batched_tokens(victim.request_id,
                                                   victim_tokens)
                budget.subtract_num_seqs(victim.request_id,
                                         victim.get_max_num_running_seqs())

                self._preempt(victim, blocks_to_swap_out)
                waiting_queue.appendleft(victim)
                num_preempted += 1

            # The candidate was only inspected; return it to the queue front.
            waiting_queue.appendleft(candidate)
            self.remove_seq_from_computed_blocks_tracker(
                candidate, SequenceStatus.WAITING)

        self.waiting = deque(sorted(waiting_queue, key=self._get_priority))
        self.running = running_queue
        return num_preempted

    def _schedule_prefills(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> SchedulerPrefillOutputs:
        """Schedule sequence groups that are in prefill stage.

        Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
        as a new prefill (that starts from beginning -> most recently generated
        tokens).

        It schedules waiting requests as long as they fit `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.

        Requests at the head of ``self.waiting`` are handled one at a time:
          * Longer than ``_get_prompt_limit`` allows, or reported NEVER by the
            block manager: every sequence is marked FINISHED_IGNORED and the
            group is returned in ``ignored_seq_groups``.
          * Rejected by ``partial_prefill_metadata``, mixing prompt embeds
            with the groups already scheduled, or needing a LoRA slot when
            none is free: skipped for now and put back at the front of the
            waiting queue in their original order.
          * Reported LATER by the block manager, token budget exhausted, or
            not fitting ``budget``: scheduling stops.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are scheduled.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are scheduled.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` does not have enough capacity to
                schedule all tokens.
            partial_prefill_metadata: information about the partial prefills
                that are currently running

        Returns:
            SchedulerPrefillOutputs.
        """
        if budget.remaining_token_budget() == 0:
            # Do nothing: Can't add any more prefill anyway
            return SchedulerPrefillOutputs(
                seq_groups=[],
                ignored_seq_groups=[],
                num_lookahead_slots=self._get_num_lookahead_slots(
                    is_prefill=True, enable_chunking=enable_chunking),
            )
        ignored_seq_groups: List[SequenceGroup] = []
        seq_groups: List[ScheduledSequenceGroup] = []
        using_prompt_embeds: bool = False

        waiting_queue = self.waiting

        # Requests skipped during this pass. appendleft here plus extendleft
        # after the loop puts them back at the queue front in original order.
        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
        while self._passed_delay(time.time()) and waiting_queue:
            seq_group = waiting_queue[0]

            waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            if (partial_prefill_metadata is not None
                    and not partial_prefill_metadata.can_schedule(seq_group)):
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue
            num_new_tokens_uncached, num_new_tokens_cached = (
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group,
                    SequenceStatus.WAITING,
                    enable_chunking,
                    budget,
                    partial_prefill_metadata=partial_prefill_metadata,
                ))
            num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached

            if not enable_chunking:
                num_prompt_tokens = waiting_seqs[0].get_len()
                assert num_new_tokens == num_prompt_tokens

            prompt_limit = self._get_prompt_limit(seq_group)
            if num_new_tokens > prompt_limit:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds limit of %d",
                    num_new_tokens,
                    prompt_limit,
                )
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.FINISHED_IGNORED)
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            num_lookahead_slots: int = 0

            # If the sequence group cannot be allocated, stop.
            can_allocate = self.block_manager.can_allocate(
                seq_group, num_lookahead_slots=num_lookahead_slots)
            if can_allocate == AllocStatus.LATER:
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break
            elif can_allocate == AllocStatus.NEVER:
                logger.warning(
                    "Input prompt (%d tokens) + lookahead slots (%d) is "
                    "too long and exceeds the capacity of block_manager",
                    num_new_tokens,
                    num_lookahead_slots,
                )
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.FINISHED_IGNORED)
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            # We cannot mix sequence groups that use prompt embeds and
            # those that do not; the first scheduled group decides.
            if not seq_groups:
                using_prompt_embeds = seq_group.uses_prompt_embeds()
            if using_prompt_embeds != seq_group.uses_prompt_embeds():
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and lora_int_id not in curr_loras
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    self.remove_seq_from_computed_blocks_tracker(
                        seq_group, SequenceStatus.WAITING)
                    leftover_waiting_sequences.appendleft(seq_group)
                    waiting_queue.popleft()
                    continue

            if (budget.num_batched_tokens
                    >= self.scheduler_config.max_num_batched_tokens):
                # We've reached the budget limit - since there might be
                # continuous prefills in the running queue, we should break
                # to avoid scheduling any new prefills.
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break

            num_new_seqs = seq_group.get_max_num_running_seqs()
            if num_new_tokens_uncached == 0 or not budget.can_schedule(
                    num_new_tokens=num_new_tokens_uncached,
                    num_new_seqs=num_new_seqs,
            ):
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break

            # Can schedule this request.
            if curr_loras is not None and lora_int_id > 0:
                curr_loras.add(lora_int_id)
            waiting_queue.popleft()
            self._allocate_and_set_running(seq_group)

            if partial_prefill_metadata is not None:
                partial_prefill_metadata.maybe_increment_partial_prefills(
                    seq_group)

            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=num_new_tokens))
            budget.add_num_batched_tokens(
                seq_group.request_id,
                num_batched_tokens=num_new_tokens_uncached,
                num_cached_tokens=num_new_tokens_cached,
            )
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Queue requests that couldn't be scheduled.
        waiting_queue.extendleft(leftover_waiting_sequences)
        if seq_groups:
            self.prev_prompt = True

        return SchedulerPrefillOutputs(
            seq_groups=seq_groups,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=True, enable_chunking=enable_chunking),
        )

    def _schedule_default(self) -> SchedulerOutputs:
        """Schedule queued requests.

        The current policy is designed to optimize the throughput. First,
        it batches as many prefill requests as possible. Then it schedules
        decodes. If there's pressure on GPU memory, decode requests can
        be swapped or preempted.

        Steps:
          1. If nothing is swapped out, schedule new prefills (no chunking).
          2. With the "priority" policy and no prefills scheduled, run
             priority-based preemption.
          3. If no prefills were scheduled, schedule running decodes. Only
             when nothing was preempted or swapped out, try to swap groups
             back in.
          4. Update ``self.waiting``, ``self.running`` and ``self.swapped``
             and assemble the ``SchedulerOutputs``.

        Returns:
            SchedulerOutputs for this scheduling step.
        """
        # Include running requests to the budget.
        budget = SchedulingBudget(
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
        # Make sure we include num running seqs before scheduling prefill,
        # so that we don't schedule beyond max_num_seqs for prefill.
        for seq_group in self.running:
            budget.add_num_seqs(seq_group.request_id,
                                seq_group.get_max_num_running_seqs())
        curr_loras = ({
            seq_group.lora_int_id
            for seq_group in self.running if seq_group.lora_int_id > 0
        } if self.lora_enabled else None)

        prefills = SchedulerPrefillOutputs.create_empty()
        running_scheduled = SchedulerRunningOutputs.create_empty()
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # If any requests are swapped, prioritize swapped requests.
        if not self.swapped:
            prefills = self._schedule_prefills(budget,
                                               curr_loras,
                                               enable_chunking=False)

        if (not prefills.seq_groups
                and self.scheduler_config.policy == "priority"):
            self._schedule_priority_preemption(budget)

        # Don't schedule decodes if prefills are scheduled.
        # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
        # only contains decode requests, not chunked prefills.
        if not prefills.seq_groups:
            running_scheduled = self._schedule_running(budget,
                                                       curr_loras,
                                                       enable_chunking=False)

            # If any sequence group is preempted, do not swap in any sequence
            # group, because it means there's no slot for new running
            # requests.
            if (len(running_scheduled.preempted) +
                    len(running_scheduled.swapped_out) == 0):
                swapped_in = self._schedule_swapped(budget, curr_loras)

        assert (budget.num_batched_tokens
                <= self.scheduler_config.max_num_batched_tokens)
        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

        # Update waiting requests.
        self.waiting.extendleft(running_scheduled.preempted)
        # Update new running requests.
        if prefills.seq_groups:
            self.running.extend([s.seq_group for s in prefills.seq_groups])

        self.running.extend(running_scheduled.decode_seq_groups_list)

        if swapped_in.decode_seq_groups:
            self.running.extend(
                [s.seq_group for s in swapped_in.decode_seq_groups])

        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)
        preempted = len(running_scheduled.preempted) + len(
            running_scheduled.swapped_out)

        # There should be no prefill from running queue because this policy
        # doesn't allow chunked prefills.
        assert len(running_scheduled.prefill_seq_groups) == 0
        assert len(swapped_in.prefill_seq_groups) == 0

        # Merge lists
        num_prefill_groups = len(prefills.seq_groups)
        ignored_seq_groups_for_embeds: List[SequenceGroup] = []
        if num_prefill_groups > 0:
            scheduled_seq_groups = prefills.seq_groups
            scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
        else:
            scheduled_seq_groups = running_scheduled.decode_seq_groups
            if scheduled_seq_groups:
                # A batch cannot mix groups that use prompt embeds with groups
                # that do not. Keep the kind of the first decode group and
                # report the others as ignored, partitioning in one pass.
                # NOTE(review): the dropped groups were already added to
                # self.running via decode_seq_groups_list above; confirm this
                # is intended.
                using_prompt_embeds = scheduled_seq_groups[
                    0].seq_group.uses_prompt_embeds()
                kept_seq_groups: List[ScheduledSequenceGroup] = []
                for scheduled in scheduled_seq_groups:
                    if (scheduled.seq_group.uses_prompt_embeds()
                            == using_prompt_embeds):
                        kept_seq_groups.append(scheduled)
                    else:
                        ignored_seq_groups_for_embeds.append(
                            scheduled.seq_group)
                # Replace the list only when something was filtered out, so
                # the common case keeps extending the original decode list.
                if ignored_seq_groups_for_embeds:
                    scheduled_seq_groups = kept_seq_groups

        scheduled_seq_groups.extend(swapped_in.decode_seq_groups)

        blocks_to_copy = running_scheduled.blocks_to_copy
        blocks_to_copy.extend(swapped_in.blocks_to_copy)

        ignored_seq_groups = prefills.ignored_seq_groups
        ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
        ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)

        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_seq_groups,
            num_prefill_groups=num_prefill_groups,
            num_batched_tokens=budget.num_batched_tokens +
            budget.num_cached_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=preempted,
        )

    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
        """Schedule queued requests.

        Chunked prefill allows prefill requests to be split into chunks and
        batched together with decode requests. This policy:

        1. schedules as many decoding requests as possible,
        2. schedules chunked prefill requests that are not finished,
        3. schedules swapped requests, and
        4. schedules new prefill requests.

        The policy can sustain high GPU utilization because it can put
        prefill and decode requests in the same batch, while it improves
        inter-token latency because decode requests don't need to be blocked
        by prefill requests.

        Side effects: preempted groups are prepended to ``self.waiting`` (via
        ``extendleft``), scheduled groups are appended to ``self.running`` and
        swapped-out groups are appended to ``self.swapped``.

        Returns:
            The ``SchedulerOutputs`` for this step. In
            ``scheduled_seq_groups`` all prefill groups come before all decode
            groups.
        """
        budget = SchedulingBudget(
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
        curr_loras: Set[int] = set()

        # Swap-in is skipped below when preemption happens, so start empty.
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Create partial prefill metadata
        partial_prefill_metadata = PartialPrefillMetadata.from_queues(
            running=self.running,
            waiting=self.waiting,
            scheduler_config=self.scheduler_config,
        )

        # Decoding should be always scheduled first by fcfs.
        running_scheduled = self._schedule_running(
            budget,
            curr_loras,
            enable_chunking=True,
            partial_prefill_metadata=partial_prefill_metadata,
        )

        # Schedule swapped out requests.
        # If preemption happens, it means we don't have space for swap-in.
        if len(running_scheduled.preempted) + len(
                running_scheduled.swapped_out) == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)

        # NOTE(review): unlike _schedule_default, this path does not separate
        # prompt-embeds prefills from token-id prefills; confirm that mixing
        # them is prevented elsewhere when chunked prefill is enabled.
        prefills = self._schedule_prefills(
            budget,
            curr_loras,
            enable_chunking=True,
            partial_prefill_metadata=partial_prefill_metadata,
        )

        assert (budget.num_batched_tokens
                <= self.scheduler_config.max_num_batched_tokens)
        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

        # Update waiting requests.
        self.waiting.extendleft(running_scheduled.preempted)

        # Update new running requests.
        # By default, vLLM scheduler prioritizes prefills.
        # Once chunked prefill is enabled,
        # the policy is changed to prioritize decode requests.
        self.running.extend(
            [s.seq_group for s in swapped_in.decode_seq_groups])
        self.running.extend(
            [s.seq_group for s in swapped_in.prefill_seq_groups])
        self.running.extend(
            [s.seq_group for s in running_scheduled.decode_seq_groups])
        # Because multiple prefills may be running concurrently, we need to
        # make sure that prefills which are scheduled to finish are listed
        # before those that won't. This is so that on the next scheduling
        # iteration when they have transitioned to the decode stage, they are
        # properly prioritized over sequences that are still in the prefill
        # stage.
        self.running.extend(
            self._order_finishing_prefills_first(
                running_scheduled.prefill_seq_groups))
        self.running.extend([s.seq_group for s in prefills.seq_groups])

        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)
        # Put prefills first due to Attention backend ordering assumption.
        scheduled_seq_groups = (prefills.seq_groups +
                                running_scheduled.prefill_seq_groups +
                                swapped_in.prefill_seq_groups +
                                running_scheduled.decode_seq_groups +
                                swapped_in.decode_seq_groups)
        num_prefill_groups = (len(prefills.seq_groups) +
                              len(swapped_in.prefill_seq_groups) +
                              len(running_scheduled.prefill_seq_groups))
        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_seq_groups,
            num_prefill_groups=num_prefill_groups,
            num_batched_tokens=budget.num_batched_tokens +
            budget.num_cached_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=running_scheduled.blocks_to_copy +
            swapped_in.blocks_to_copy,
            ignored_seq_groups=prefills.ignored_seq_groups +
            swapped_in.infeasible_seq_groups,
            num_lookahead_slots=0,
            running_queue_size=len(self.running),
            preempted=(len(running_scheduled.preempted) +
                       len(running_scheduled.swapped_out)),
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
1450

1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
    def _order_finishing_prefills_first(
        self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
    ) -> List[SequenceGroup]:
        """Returns a list of prefilling SequenceGroups where sequences that are
        scheduled to finish prefilling are listed first.

        A group finishes its prefill in this step when its scheduled
        ``token_chunk_size`` equals its number of uncomputed tokens. The
        partition is stable: the input order is preserved within the
        finishing groups and within the not-finishing groups.
        """
        finishing: List[SequenceGroup] = []
        not_finishing: List[SequenceGroup] = []
        # Single pass; the predicate is evaluated once per group.
        for scheduled in scheduled_prefill_seqs:
            group = scheduled.seq_group
            num_uncomputed = group.get_num_uncomputed_tokens()
            if num_uncomputed == scheduled.token_chunk_size:
                finishing.append(group)
            else:
                not_finishing.append(group)
        return finishing + not_finishing

1466
1467
1468
1469
1470
1471
1472
    def _schedule(self) -> SchedulerOutputs:
        """Schedule queued requests.

        Uses the chunked-prefill policy when
        ``scheduler_config.chunked_prefill_enabled`` is set and the default
        policy otherwise.
        """
        use_chunked_prefill = self.scheduler_config.chunked_prefill_enabled
        if not use_chunked_prefill:
            return self._schedule_default()
        return self._schedule_chunked_prefill()

1473
1474
    def _can_append_slots(self, seq_group: SequenceGroup,
                          enable_chunking: bool) -> bool:
        """Determine whether or not we have enough space in the KV cache to
        continue generation of the sequence group.

        Args:
            seq_group: The sequence group that needs new slots.
            enable_chunking: Whether chunked prefill is enabled; forwarded to
                ``_get_num_lookahead_slots``.

        Returns:
            The block manager's ``can_append_slots`` answer. In test-only
            artificial-preemption mode this may instead return False at
            random, while ``artificial_preempt_cnt`` is still positive.
        """
        # It is True only for testing case to trigger artificial preemption.
        if (self.enable_artificial_preemption
                and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
                and self.artificial_preempt_cnt > 0):
            self.artificial_preempt_cnt -= 1
            return False

        is_prefill = seq_group.is_prefill()
        num_lookahead_slots = self._get_num_lookahead_slots(
            is_prefill, enable_chunking)

        return self.block_manager.can_append_slots(
            seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
1491

1492
    def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
        """Return whether async output processing may be used for a group.

        Async output processing is allowed only when the sequence group holds
        a single sequence. That is the case when it has no sampling params
        (presumably a pooling request; confirm) or when it samples exactly
        one output (``n == 1``).
        """
        # True when the group has exactly one sequence. The previous name
        # ``no_single_seq`` suggested the opposite.
        is_single_seq = seq_group.sampling_params is None or (
            seq_group.sampling_params.n == 1)
        return is_single_seq
1498
1499
1500
1501

    def schedule(
            self
    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
        """Run one scheduling step and build per-group worker metadata.

        Returns:
            A tuple ``(seq_group_metadata_list, scheduler_outputs,
            allow_async_output_proc)``:

            - one metadata object per scheduled sequence group, in the order
              of ``scheduler_outputs.scheduled_seq_groups``;
            - the scheduler outputs produced by ``_schedule``;
            - whether async output processing may be used. This starts from
              ``self.use_async_output_proc`` and is turned off as soon as any
              scheduled group disallows it.
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_start_time = time.perf_counter()

        scheduler_outputs: SchedulerOutputs = self._schedule()
        now = time.time()

        if not self.cache_config.enable_prefix_caching:
            common_computed_block_nums = []

        allow_async_output_proc: bool = self.use_async_output_proc

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            token_chunk_size = scheduled_seq_group.token_chunk_size
            seq_group.maybe_set_first_scheduled_time(now)

            # NOTE(review): this cached object is cleared but then replaced
            # by a freshly built metadata object below. The call is kept
            # because it presumably advances the cache's internal cursor.
            seq_group_metadata = self._seq_group_metadata_cache[
                self.cache_id].get_object()
            seq_group_metadata.seq_data.clear()
            seq_group_metadata.block_tables.clear()

            # seq_id -> SequenceData
            seq_data: Dict[int, SequenceData] = {}
            # seq_id -> physical block numbers
            block_tables: Dict[int, List[int]] = {}

            if seq_group.is_encoder_decoder():
                # Encoder associated with SequenceGroup
                encoder_seq = seq_group.get_encoder_seq()
                assert encoder_seq is not None
                encoder_seq_data = encoder_seq.data
                # Block table for cross-attention
                # Also managed at SequenceGroup level
                cross_block_table = self.block_manager.get_cross_block_table(
                    seq_group)
            else:
                encoder_seq_data = None
                cross_block_table = None

            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                self.block_manager.access_all_blocks_in_seq(seq, now)

            if self.cache_config.enable_prefix_caching:
                common_computed_block_nums = (
                    self.block_manager.get_common_computed_block_ids(
                        seq_group.get_seqs(status=SequenceStatus.RUNNING)))

            do_sample = True
            is_prompt = seq_group.is_prefill()
            # We should send the metadata to workers when the first prefill
            # is sent. Subsequent requests could be chunked prefill or decode.
            is_first_prefill = False
            if is_prompt:
                seqs = seq_group.get_seqs()
                # Prefill has only 1 sequence.
                assert len(seqs) == 1
                num_computed_tokens = seqs[0].data.get_num_computed_tokens()
                is_first_prefill = num_computed_tokens == 0
                # In the next iteration, all prompt tokens are not computed.
                # It means the prefill is chunked, and we don't need sampling.
                # NOTE: We use get_len instead of get_prompt_len because when
                # a sequence is preempted, prefill includes previous generated
                # output tokens.
                if (token_chunk_size + num_computed_tokens
                        < seqs[0].data.get_len()):
                    do_sample = False

            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
            if is_first_prefill or not self.scheduler_config.send_delta_data:
                seq_group_metadata = SequenceGroupMetadata(
                    request_id=seq_group.request_id,
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=seq_group.sampling_params,
                    block_tables=block_tables,
                    do_sample=do_sample,
                    pooling_params=seq_group.pooling_params,
                    token_chunk_size=token_chunk_size,
                    lora_request=seq_group.lora_request,
                    computed_block_nums=common_computed_block_nums,
                    encoder_seq_data=encoder_seq_data,
                    cross_block_table=cross_block_table,
                    state=seq_group.state,
                    token_type_ids=seq_group.token_type_ids,
                    # `multi_modal_data` will only be present for the 1st comm
                    # between engine and worker.
                    # the subsequent comms can still use delta, but
                    # `multi_modal_data` will be None.
                    multi_modal_data=(seq_group.multi_modal_data
                                      if scheduler_outputs.num_prefill_groups
                                      > 0 else None),
                    multi_modal_placeholders=(
                        seq_group.multi_modal_placeholders
                        if scheduler_outputs.num_prefill_groups > 0 else None),
                )
            else:
                # When SPMD mode is enabled, we only send delta data except for
                # the first request to reduce serialization cost.
                seq_data_delta = {
                    delta_seq_id: data.get_delta_and_reset()
                    for delta_seq_id, data in seq_data.items()
                }
                seq_group_metadata = SequenceGroupMetadataDelta(
                    seq_data_delta,
                    seq_group.request_id,
                    block_tables,
                    is_prompt,
                    do_sample=do_sample,
                    token_chunk_size=token_chunk_size,
                    computed_block_nums=common_computed_block_nums,
                )
            seq_group_metadata_list.append(seq_group_metadata)

            if allow_async_output_proc:
                allow_async_output_proc = self._allow_async_output_proc(
                    seq_group)

        # Now that the batch has been created, we can assume all blocks in the
        # batch will have been computed before the next scheduling invocation.
        # This is because the engine assumes that a failure in model execution
        # will crash the vLLM instance / will not retry.
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            self.block_manager.mark_blocks_as_computed(
                scheduled_seq_group.seq_group,
                scheduled_seq_group.token_chunk_size)

        self._seq_group_metadata_cache[self.next_cache_id].reset()

        scheduler_time = time.perf_counter() - scheduler_start_time
        # Add this to scheduler time to all the sequences that are currently
        # running. This will help estimate if the scheduler is a significant
        # component in the e2e latency.
        for seq_group in self.running:
            if seq_group is not None and seq_group.metrics is not None:
                if seq_group.metrics.scheduler_time is not None:
                    seq_group.metrics.scheduler_time += scheduler_time
                else:
                    seq_group.metrics.scheduler_time = scheduler_time

        # Move to next cache (if exists)
        self.cache_id = self.next_cache_id

        # Return results
        return (seq_group_metadata_list, scheduler_outputs,
                allow_async_output_proc)
1655

1656
1657
    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Fork ``parent_seq`` into ``child_seq`` in the block manager.

        Delegates to ``block_manager.fork`` with the parent first and the
        child second.
        """
        manager = self.block_manager
        manager.fork(parent_seq, child_seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
1658

1659
    def free_seq(self, seq: Sequence) -> None:
        """Free a sequence from a block table.

        Delegates to ``block_manager.free`` for ``seq``.
        """
        self.block_manager.free(seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
1662

1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
    def remove_seq_from_computed_blocks_tracker(
            self, seq_group: SequenceGroup,
            status: Optional[SequenceStatus]) -> None:
        """Drop computed-block tracking for sequences of ``seq_group``.

        Args:
            seq_group: The group whose sequences should be removed from the
                computed blocks tracker.
            status: Passed unchanged to ``seq_group.get_seqs`` to select which
                sequences are processed. ``None`` presumably selects all
                sequences; confirm against ``SequenceGroup.get_seqs``.
        """
        for seq in seq_group.get_seqs(status=status):
            self._remove_seq_from_computed_blocks_tracker(seq)

    def _remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
        """Forget computed-block bookkeeping for a single sequence.

        Forwards to ``block_manager.remove_seq_from_computed_blocks_tracker``,
        which frees the sequence's entries in the computed blocks tracker
        (``_seq_id_to_blocks_hashes`` and ``_seq_id_to_num_tokens_computed``).
        """
        manager = self.block_manager
        manager.remove_seq_from_computed_blocks_tracker(seq)

1677
1678
1679
1680
1681
1682
    def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
        """Free every finished sequence of ``seq_group`` via ``free_seq``.

        Sequences that are not finished are left untouched.
        """
        for seq in seq_group.get_seqs():
            if not seq.is_finished():
                continue
            self.free_seq(seq)

1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
    def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
        """Release the resources of ``seq_group`` that are no longer needed.

        If the whole group is finished, its cross-attention blocks are freed
        and its request id is appended to ``_finished_requests_ids``, which is
        used to update the Mamba cache in the next step. In every case, the
        group's individually finished sequences are freed.
        """
        group_done = seq_group.is_finished()
        if group_done:
            # Free cross-attention block table, if it exists
            self._free_seq_group_cross_attn_blocks(seq_group)
            # Record the request so the Mamba cache can be updated in the
            # next step.
            self._finished_requests_ids.append(seq_group.request_id)

        # Free finished seqs
        self._free_finished_seqs(seq_group)

1696
    def free_finished_seq_groups(self) -> None:
        """Release finished sequence groups and drop them from ``running``.

        Every group in ``self.running`` is passed to
        ``_free_finished_seq_group``. Only groups that are not finished remain
        in ``self.running``, which is rebuilt as a deque. Groups in
        ``self._async_stopped`` (ones that reached max model len) are then
        cleaned up unconditionally: their cross-attention blocks are freed,
        their request ids are recorded and their finished sequences are freed.
        The list is then cleared.
        """
        remaining: Deque[SequenceGroup] = deque()
        for seq_group in self.running:
            self._free_finished_seq_group(seq_group)
            if not seq_group.is_finished():
                remaining.append(seq_group)

        self.running = remaining

        # Handle async stopped sequence groups
        # (ones that reached max model len)
        if self._async_stopped:
            for seq_group in self._async_stopped:
                self._free_seq_group_cross_attn_blocks(seq_group)
                self._finished_requests_ids.append(seq_group.request_id)

                # Free finished seqs
                self._free_finished_seqs(seq_group)

            self._async_stopped.clear()

1717
    def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for ``seq_group`` and mark it as running.

        The block manager allocates for the whole group first. Then every
        sequence of the group that is currently WAITING is switched to
        RUNNING.
        """
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            seq.status = SequenceStatus.RUNNING

1722
1723
1724
1725
1726
1727
    def _append_slots(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: List[Tuple[int, int]],
        enable_chunking: bool = False,
    ) -> None:
        """Appends new slots to the sequences in the given sequence group.

        Only sequences in RUNNING status are extended.

        Args:
            seq_group (SequenceGroup): The sequence group containing the
                sequences to append slots to.
            blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
                ints, the first int is the source block index, and the second
                int is the destination block index. This list is updated with
                the new source and destination block indices for the appended
                slots.
            enable_chunking (bool): True if chunked prefill is enabled.
        """
        is_prefill: bool = seq_group.is_prefill()
        num_lookahead_slots: int = self._get_num_lookahead_slots(
            is_prefill, enable_chunking)

        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            # Block copies requested by the block manager for this sequence.
            cows = self.block_manager.append_slots(seq, num_lookahead_slots)
            if cows:
                blocks_to_copy.extend(cows)
1749

1750
1751
    def _preempt(self, seq_group: SequenceGroup,
                 blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode:
        """Preempt ``seq_group`` to free KV cache space.

        The mode is taken from ``user_specified_preemption_mode`` when it is
        set ("swap" selects SWAP, anything else RECOMPUTE). Otherwise it is
        chosen automatically (see below). A warning is logged whenever the
        cumulative preemption count, checked before incrementing, is a
        multiple of 50.

        Args:
            seq_group: The sequence group to preempt.
            blocks_to_swap_out: Output list. When swapping is used, the block
                mapping returned by ``block_manager.swap_out`` is appended.

        Returns:
            The preemption mode that was applied.

        Raises:
            RuntimeError: Propagated from ``_swap_out`` when there is not
                enough CPU swap space.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if self.user_specified_preemption_mode is None:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP

        elif self.user_specified_preemption_mode == "swap":
            preemption_mode = PreemptionMode.SWAP
        else:
            preemption_mode = PreemptionMode.RECOMPUTE

        if self.num_cumulative_preemption % 50 == 0:
            logger.warning(
                "Sequence group %s is preempted by %s mode because there is "
                "not enough KV cache space. This can affect the end-to-end "
                "performance. Increase gpu_memory_utilization or "
                "tensor_parallel_size to provide more KV cache memory. "
                "total_num_cumulative_preemption=%d",
                seq_group.request_id,
                preemption_mode,
                self.num_cumulative_preemption + 1,
            )
        self.num_cumulative_preemption += 1

        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")
        return preemption_mode
1794
1795
1796
1797
1798
1799
1800
1801
1802

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Preempt a single-sequence group by discarding its blocks.

        The group's only RUNNING sequence (asserted) is set back to WAITING.
        Its blocks are freed and its state is reset with
        ``reset_state_for_recompute`` so it is recomputed when resumed. The
        group's cross-attention blocks are freed as well.
        """
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.free_seq(seq)
            seq.reset_state_for_recompute()
        self._free_seq_group_cross_attn_blocks(seq_group)
1806
1807
1808
1809

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Preempt ``seq_group`` by swapping its blocks out.

        This is a thin wrapper around ``_swap_out``. The swap-out mapping is
        appended to ``blocks_to_swap_out``.
        """
        self._swap_out(seq_group, blocks_to_swap_out)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: List[Tuple[int, int]],
    ) -> None:
        """Swap ``seq_group`` back in and mark its swapped sequences running.

        Args:
            seq_group: The swapped-out sequence group to resume.
            blocks_to_swap_in: Output list. The block mapping returned by
                ``block_manager.swap_in`` is appended to it.
        """
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.extend(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Swap ``seq_group`` out and mark its running sequences swapped.

        Args:
            seq_group: The sequence group to swap out.
            blocks_to_swap_out: Output list. The block mapping returned by
                ``block_manager.swap_out`` is appended to it.

        Raises:
            RuntimeError: If the block manager reports that the group cannot
                be swapped out because there is not enough CPU swap space.
        """
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.extend(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED
1839

1840
1841
1842
1843
1844
1845
1846
1847
    def _passed_delay(self, now: float) -> bool:
        if self.prev_prompt:
            self.last_prompt_latency = now - self.prev_time
        self.prev_time, self.prev_prompt = now, False
        # Delay scheduling prompts to let waiting queue fill up
        if self.scheduler_config.delay_factor > 0 and self.waiting:
            earliest_arrival_time = min(
                [e.metrics.arrival_time for e in self.waiting])
1848
1849
1850
            passed_delay = ((now - earliest_arrival_time)
                            > (self.scheduler_config.delay_factor *
                               self.last_prompt_latency) or not self.running)
1851
1852
1853
        else:
            passed_delay = True
        return passed_delay
1854

1855
1856
    def _get_num_lookahead_slots(self, is_prefill: bool,
                                 enable_chunking: bool) -> int:
1857
1858
1859
1860
        """The number of slots to allocate per sequence per step, beyond known
        token ids. Speculative decoding uses these slots to store KV activations
        of tokens which may or may not be accepted.
        """
1861
        return 0
1862

1863
1864
1865
1866
1867
1868
    def _get_num_new_uncached_and_cached_tokens(
        self,
        seq_group: SequenceGroup,
        status: SequenceStatus,
        enable_chunking: bool,
        budget: SchedulingBudget,
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> Tuple[int, int]:
        """
        Returns the number of new uncached and cached tokens to schedule for a
        given sequence group that's in a given `status`.

        The API could chunk the number of tokens to compute based on `budget`
        if `enable_chunking` is True. If a sequence group has multiple
        sequences (e.g., running beam search), it means it is in decoding
        phase, so chunking doesn't happen.

        Returns (0, 0) if the new token cannot be computed due to token budget.

        The cached tokens' blocks are already computed, and the attention
        backend will reuse the cached blocks rather than recomputing them. So
        the scheduler could schedule these cached tokens "for free".

        Args:
            seq_group: The sequence group to get the number of new tokens to
                schedule.
            status: The status of the sequences to get the number of new tokens
                to schedule.
            enable_chunking: Whether to chunk the number of tokens to compute.
            budget: The budget to chunk the number of tokens to compute.
            partial_prefill_metadata: information about the partial prefills
                that are currently running

        Returns:
            A tuple of two ints. The first int is the number of new uncached
            tokens to schedule. The second int is the number of cached tokens.
            If no more new tokens can be scheduled, returns (0, 0).
        """
        num_cached_new_tokens = 0
        num_uncached_new_tokens = 0

        seqs = seq_group.get_seqs(status=status)
        # Compute the number of new uncached and cached tokens for
        # each sequence.
        for seq in seqs:
            if not seq.is_prefill():
                # Decode sequences should always just have 1 uncached token
                # TODO(rickyx): Actually is this still correct for multi-step?
                num_uncached_new_tokens += 1
                continue

            num_computed_tokens_seq = seq.get_num_computed_tokens()
            all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
            if not self.cache_config.enable_prefix_caching:
                # If prefix caching is not enabled, all new tokens are uncached.
                num_uncached_new_tokens += all_num_new_tokens_seq
                continue

            # NOTE: the cache token might be currently in a block that's in an
            # evictor meaning that it's not yet allocated. However, we don't
            # exclude such tokens in the cache count because it will be
            # guaranteed to be allocated later if the sequence can be allocated.
            num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
                seq)

            # Sanity check.
            if num_cached_tokens_seq < num_computed_tokens_seq:
                # This should only happen with chunked prefill, and
                # the seq is still in prefill. The `num_cached_tokens_seq`
                # is the value we calculated on scheduling the first prefill.
                # For subsequent continuous prefill steps, we cached the
                # number of cache tokens for the sequence so the cached token
                # count could be less than the number of computed tokens.
                # See comments on `ComputedBlocksTracker` for more details.
                assert (
                    seq.is_prefill() and seq.status == SequenceStatus.RUNNING
                    and self.scheduler_config.chunked_prefill_enabled
                ), ("Number of cached tokens should not be less than the "
                    "number of computed tokens for a sequence that's still "
                    f"in prefill. But there are {num_cached_tokens_seq} cached "
                    f"tokens and {num_computed_tokens_seq} computed tokens "
                    f"for sequence {seq.seq_id}.")

            num_cached_new_tokens_seq = max(
                0, num_cached_tokens_seq - num_computed_tokens_seq)
            num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
                                           num_cached_new_tokens_seq)

            num_uncached_new_tokens += num_uncached_new_tokens_seq
            num_cached_new_tokens += num_cached_new_tokens_seq

        if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
            # For a fully cached hit sequence, we actually need to recompute the
            # last token. So we need at least 1 uncached token to schedule.
            # See ModelRunner._compute_for_prefix_cache_hit for more details.
            num_uncached_new_tokens = 1
            num_cached_new_tokens -= 1

        if enable_chunking and len(seqs) == 1:
            # Chunk if a running request cannot fit in the given budget.
            # If number of seq > 1, it means it is doing beam search
            # in a decode phase. Do not chunk.
            num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
                self.scheduler_config,
                self.cache_config,
                budget,
                self._get_prompt_limit(seq_group),
                num_uncached_new_tokens,
                self.partial_prefill_budget_lookup_list,
                partial_prefill_metadata,
            )

        return num_uncached_new_tokens, num_cached_new_tokens

    @staticmethod
    def _chunk_new_tokens_to_schedule(
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        budget: SchedulingBudget,
        prompt_limit: int,
        num_new_tokens: int,
        partial_prefill_budget_lookup_list: List[int],
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> int:
        """
        Chunks the number of new tokens to schedule based on the budget when
        chunked prefill is enabled.

        The result never exceeds the remaining token budget of this step nor
        the per-slot budget for concurrent partial prefills. If all of the
        sequence's new tokens fit within both budgets they are scheduled in
        full. Otherwise the sequence is partially prefilled; with prefix
        caching enabled, that partial chunk is rounded down to a multiple of
        the cache block size to avoid partial block matching.

        Args:
            scheduler_config: The scheduler config. Not read by this method;
                kept for interface compatibility.
            cache_config: The cache config. `enable_prefix_caching` and
                `block_size` are read.
            budget: The budget to chunk the number of tokens to compute.
            prompt_limit: The maximum number of tokens allowed in a prompt.
                Not read by this method; kept for interface compatibility.
            num_new_tokens: The number of new tokens to schedule.
            partial_prefill_budget_lookup_list: Per-slot token budgets,
                indexed by `partial_prefill_metadata.schedulable_prefills`.
            partial_prefill_metadata: Partial-prefill state for this step.
                When None, the slot budget is the remaining token budget.

        Returns:
            The number of new tokens to schedule after chunking. This can be
            0 when the budgets are exhausted, or when prefix caching is
            enabled and the budget left is smaller than one block while the
            sequence still has more tokens than fit.
        """
        remaining_token_budget = budget.remaining_token_budget()

        # Get the number of tokens to allocate to this prefill slot.
        prefill_slot_budget = (
            remaining_token_budget if partial_prefill_metadata is None else
            partial_prefill_budget_lookup_list[
                partial_prefill_metadata.schedulable_prefills])

        # Don't exceed either the total budget or the slot budget.
        token_limit = min(remaining_token_budget, prefill_slot_budget)

        if num_new_tokens <= token_limit:
            # Everything left fits, so this chunk finishes the sequence's
            # prefill. A finishing chunk does not need to cover a full block,
            # so no block-size rounding is applied here (previously a fitting
            # request could be truncated, even down to 0 tokens).
            return num_new_tokens

        if cache_config.enable_prefix_caching:
            # When prefix caching is enabled and we're partially prefilling
            # a sequence, we always allocate a number of new tokens that is
            # divisible by the block size to avoid partial block matching.
            # Round the limit down to the next lowest block-size multiple.
            block_size = cache_config.block_size
            token_limit = (token_limit // block_size) * block_size

        return min(num_new_tokens, token_limit)