scheduler.py 71 KB
Newer Older
1
import enum
2
3
import os
import random
4
import time
5
from collections import deque
6
from dataclasses import dataclass, field
7
8
9
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
12
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.logger import init_logger
14
from vllm.lora.request import LoRARequest
15
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
17
18
                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
                           SequenceStatus)
19
from vllm.utils import Device, PyObjectCache
Woosuk Kwon's avatar
Woosuk Kwon committed
20

Woosuk Kwon's avatar
Woosuk Kwon committed
21
logger = init_logger(__name__)
22

23
24
25
26
27
28
29
# Test-only. If configured, decode is preempted with
# ARTIFICIAL_PREEMPTION_PROB% probability.
# NOTE(review): the flag is computed as bool(<env string>), so ANY non-empty
# value of VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT (including "0" or "false")
# enables it; only an unset or empty variable leaves it disabled.
ENABLE_ARTIFICIAL_PREEMPT = bool(
    os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False))  # noqa
# NOTE(review): the value is a fraction (0.5), not a percentage as the
# comment above suggests - presumably a 50% chance; confirm at the use site.
ARTIFICIAL_PREEMPTION_PROB = 0.5
# Initial value of Scheduler.artificial_preempt_cnt when artificial
# preemption is enabled (see Scheduler.__init__).
ARTIFICIAL_PREEMPTION_MAX_CNT = 500

Woosuk Kwon's avatar
Woosuk Kwon committed
30

31
32
33
34
35
36
37
38
39
40
41
42
43
class PreemptionMode(enum.Enum):
    """Strategies for preempting a sequence group.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """
    # Move the victim's blocks to CPU memory and restore them on resume.
    SWAP = enum.auto()
    # Drop the victim's blocks and recompute them on resume.
    RECOMPUTE = enum.auto()


44
45
@dataclass
class SchedulingBudget:
    """The available slots for scheduling.

    Tracks how many tokens and sequences have been batched so far against
    the two limits ``token_budget`` and ``max_num_seqs``.

    TODO(sang): Right now, the budget is request_id-aware meaning it can ignore
    budget update from the same request_id. It is because in normal scheduling
    path, we update RUNNING num_seqs ahead of time, meaning it could be
    updated more than once when scheduling RUNNING requests. Since this won't
    happen if we only have chunked prefill scheduling, we can remove this
    feature from the API when chunked prefill is enabled by default.
    """
    # Upper bound on the number of batched tokens.
    token_budget: int
    # Upper bound on the number of batched sequences.
    max_num_seqs: int
    # Request ids that have already been counted towards the token total.
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    # Request ids that have already been counted towards the sequence total.
    _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int) -> bool:
        """Return True if the extra tokens AND sequences both fit the budget.

        Both arguments must be non-zero (enforced with assertions).
        """
        assert num_new_tokens != 0
        assert num_new_seqs != 0
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

    def remaining_token_budget(self) -> int:
        """Return the number of tokens that can still be batched."""
        return self.token_budget - self.num_batched_tokens

    def add_num_batched_tokens(self, req_id: str,
                               num_batched_tokens: int) -> None:
        """Count ``num_batched_tokens`` for ``req_id``, at most once per id."""
        if req_id in self._request_ids_num_batched_tokens:
            # Already accounted for; repeated updates are ignored.
            return

        self._request_ids_num_batched_tokens.add(req_id)
        self._num_batched_tokens += num_batched_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int) -> None:
        """Undo a previous token update; a no-op for an unknown ``req_id``."""
        if req_id in self._request_ids_num_batched_tokens:
            self._request_ids_num_batched_tokens.remove(req_id)
            self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int) -> None:
        """Count ``num_curr_seqs`` for ``req_id``, at most once per id."""
        if req_id in self._request_ids_num_curr_seqs:
            # Already accounted for; repeated updates are ignored.
            return

        self._request_ids_num_curr_seqs.add(req_id)
        self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int) -> None:
        """Undo a previous sequence update; a no-op for an unknown id."""
        if req_id in self._request_ids_num_curr_seqs:
            self._request_ids_num_curr_seqs.remove(req_id)
            self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self) -> int:
        """Total number of tokens counted so far."""
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self) -> int:
        """Total number of sequences counted so far."""
        return self._num_curr_seqs

104

105
106
107
108
109
110
111
112
113
114
@dataclass
class ScheduledSequenceGroup:
    """A sequence group paired with the token count scheduled for it."""
    # A sequence group that's scheduled.
    seq_group: SequenceGroup
    # The total chunk size (number of tokens) to process for next iteration.
    # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
    # chunked, it can be smaller than that.
    token_chunk_size: int


115
@dataclass
class SchedulerOutputs:
    """The scheduling decision made from a scheduler."""
    # Scheduled sequence groups.
    scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
    # Number of prefill groups scheduled.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]]
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]]
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # The number of requests in the running queue
    running_queue_size: int
    # NOTE(review): presumably the number of preempted sequence groups in
    # this step - confirm against the code that constructs this object.
    preempted: int

    def __post_init__(self):
        """Validate the swap lists and derive the adapter counters."""
        # Swap in and swap out should never happen at the same time.
        assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)

        self.num_loras: int = len(self.lora_requests)
        if self.num_loras > 0:
            # Reorder the scheduled groups by LoRA id when any LoRA is used.
            self._sort_by_lora_ids()

        self.num_prompt_adapters: int = len(self.prompt_adapter_requests)

    def is_empty(self) -> bool:
        """Return True when nothing is scheduled and no block op is pending."""
        # NOTE: We do not consider the ignored sequence groups.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

    def _sort_by_lora_ids(self):
        """Sort scheduled groups by (lora_int_id, request_id)."""
        self.scheduled_seq_groups = sorted(
            self.scheduled_seq_groups,
            key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """Distinct non-None LoRA requests among the scheduled groups."""
        return {
            g.seq_group.lora_request
            for g in self.scheduled_seq_groups
            if g.seq_group.lora_request is not None
        }

    @property
    def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
        """Distinct non-None prompt-adapter requests among scheduled groups."""
        return {
            g.seq_group.prompt_adapter_request
            for g in self.scheduled_seq_groups
            if g.seq_group.prompt_adapter_request is not None
        }

174

175
@dataclass
class SchedulerRunningOutputs:
    """The requests that are scheduled from a running queue.

    Could contain prefill (prefill that's chunked) or decodes. If there's not
    enough memory, it can be preempted (for recompute) or swapped out.
    """
    # Selected sequences that are running and in a decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are running and in a prefill phase.
    # I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The preempted sequences.
    preempted: List[SequenceGroup]
    # Sequences that are swapped out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    # Optimization for fast-access to seq_group lists
    decode_seq_groups_list: List[SequenceGroup]
    prefill_seq_groups_list: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an instance with empty lists and zero lookahead slots."""
        return SchedulerRunningOutputs(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            decode_seq_groups_list=[],
            prefill_seq_groups_list=[],
        )


@dataclass
class SchedulerSwappedInOutputs:
    """The requests that are scheduled from a swap queue.

    Could contain prefill (prefill that's chunked) or decodes.
    """
    # Selected sequences that are going to be swapped in and is in a
    # decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are going to be swapped in and in a prefill
    # phase. I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # Infeasible sequence groups.
    infeasible_seq_groups: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance with empty lists and zero lookahead slots."""
        return SchedulerSwappedInOutputs(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            blocks_to_swap_in=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            infeasible_seq_groups=[],
        )


@dataclass
class SchedulerPrefillOutputs:
    """The requests that are scheduled from a waiting queue.

    Could contain a fresh prefill requests or preempted requests that need
    to be recomputed from scratch.
    """
    # Selected sequences for prefill.
    seq_groups: List[ScheduledSequenceGroup]
    # Ignored sequence groups.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an instance with empty lists and zero lookahead slots."""
        return SchedulerPrefillOutputs(
            seq_groups=[],
            ignored_seq_groups=[],
            num_lookahead_slots=0,
        )


272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
def seq_group_metadata_builder():
    """Build a blank ``SequenceGroupMetadata`` placeholder.

    Passed to ``PyObjectCache`` in ``Scheduler.__init__`` as the factory for
    new cache entries.
    """
    blank_fields = dict(
        request_id="",
        is_prompt=False,
        seq_data={},
        sampling_params=None,
        block_tables={},
    )
    return SequenceGroupMetadata(**blank_fields)


def scheduler_running_outputs_builder():
    """Build an empty ``SchedulerRunningOutputs`` placeholder.

    Passed to ``PyObjectCache`` in ``Scheduler.__init__`` as the factory for
    new cache entries. Every list field gets its own fresh list.
    """
    empty_outputs = SchedulerRunningOutputs(
        decode_seq_groups=[],
        prefill_seq_groups=[],
        preempted=[],
        swapped_out=[],
        blocks_to_swap_out=[],
        blocks_to_copy=[],
        num_lookahead_slots=0,
        prefill_seq_groups_list=[],
        decode_seq_groups_list=[],
    )
    return empty_outputs


def scheduled_seq_group_builder():
    """Build a placeholder ``ScheduledSequenceGroup``.

    Passed to ``PyObjectCache`` in ``Scheduler.__init__`` as the factory for
    new cache entries.
    """
    # SequenceGroup.__new__ creates an uninitialized SequenceGroup (its
    # __init__ is not run), so the `seq_group` field holds a real
    # SequenceGroup instance instead of None.
    return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup),
                                  token_chunk_size=0)
296
297


Woosuk Kwon's avatar
Woosuk Kwon committed
298
299
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
300
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        pipeline_parallel_size: int = 1,
        output_proc_callback: Optional[Callable] = None,
    ) -> None:
        """Set up the queues, the block space manager and the object caches.

        Args:
            scheduler_config: Scheduler settings; `task` and
                `preemption_mode` are read here.
            cache_config: KV-cache settings (block size, GPU/CPU block
                counts, sliding window, prefix caching, attention-free flag).
            lora_config: LoRA settings, or None when LoRA is disabled.
            pipeline_parallel_size: The GPU and CPU block counts are
                floor-divided by this value before the block space manager
                is created.
            output_proc_callback: When not None, async output processing is
                considered enabled and two sets of object caches are kept
                instead of one.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        # Embedding tasks and attention-free models use the placeholder
        # block space manager; everything else uses the self-attention one.
        version = "selfattn"
        if (self.scheduler_config.task == "embedding"
                or self.cache_config.is_attention_free):
            version = "placeholder"

        BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
            version)

        # The truthiness checks skip the division when a count is None or 0.
        num_gpu_blocks = cache_config.num_gpu_blocks
        if num_gpu_blocks:
            num_gpu_blocks //= pipeline_parallel_size

        num_cpu_blocks = cache_config.num_cpu_blocks
        if num_cpu_blocks:
            num_cpu_blocks //= pipeline_parallel_size

        # Create the block space manager.
        self.block_manager = BlockSpaceManagerImpl(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)

        # Sequence groups in the WAITING state.
        # Contain new prefill or preempted requests.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        # Contain decode requests.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        # Contain decode requests that are swapped out.
        self.swapped: Deque[SequenceGroup] = deque()
        # Sequence groups finished requests ids since last step iteration.
        # It lets the model know that any state associated with these requests
        # can and must be released after the current step.
        # This is used to evict the finished requests from the Mamba cache.
        self._finished_requests_ids: List[str] = list()
        # Time at previous scheduling step
        self.prev_time = 0.0
        # Did we schedule a prompt at previous step?
        self.prev_prompt = False
        # Latency of the last prompt step
        self.last_prompt_latency = 0.0
        # preemption mode, RECOMPUTE or SWAP
        self.user_specified_preemption_mode = scheduler_config.preemption_mode

        # The following field is test-only. It is used to inject artificial
        # preemption.
        self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
        self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
                                       if self.enable_artificial_preemption
                                       else 0)
        self.num_cumulative_preemption: int = 0

        # Used to cache python objects
        self._seq_group_metadata_cache: List[PyObjectCache] = []
        self._scheduler_running_outputs_cache: List[PyObjectCache] = []
        self._scheduled_seq_group_cache: List[PyObjectCache] = []

        # For async output processing, we need to swap cache buffers between
        # iterations. I.e. since the output processing is lagged one step,
        # we cannot reuse the cached objects immediately when the schedule()
        # is called again, but only when schedule() is called the second time.
        self.output_proc_callback = output_proc_callback
        self.use_async_output_proc = self.output_proc_callback is not None
        self.num_cache_iters = 2 if self.use_async_output_proc else 1

        self.cache_id = 0
        for i in range(self.num_cache_iters):
            self._seq_group_metadata_cache.append(
                PyObjectCache(seq_group_metadata_builder))
            self._scheduler_running_outputs_cache.append(
                PyObjectCache(scheduler_running_outputs_builder))
            self._scheduled_seq_group_cache.append(
                PyObjectCache(scheduled_seq_group_builder))

        # For async postprocessor, the extra decode run cannot be done
        # when the request reaches max_model_len. In this case, the request
        # will be stopped during schedule() call and added to this stop list
        # for processing and deallocation by the free_finished_seq_groups()
        self._async_stopped: List[SequenceGroup] = []

    @property
    def next_cache_id(self):
        return (self.cache_id + 1) % self.num_cache_iters
401

402
403
404
405
    @property
    def lora_enabled(self) -> bool:
        return bool(self.lora_config)

406
407
408
409
410
    @property
    def num_decoding_tokens_per_seq(self) -> int:
        """The number of new tokens."""
        return 1

411
    def add_seq_group(self, seq_group: SequenceGroup) -> None:
412
        # Add sequence groups to the waiting queue.
413
        self.waiting.append(seq_group)
Woosuk Kwon's avatar
Woosuk Kwon committed
414

415
416
417
418
419
420
421
422
423
424
    def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
        # Add sequence groups to the running queue.
        # Only for testing purposes.
        self.running.append(seq_group)

    def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
        # Add sequence groups to the swapped queue.
        # Only for testing purposes.
        self.swapped.append(seq_group)

Antoni Baum's avatar
Antoni Baum committed
425
    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
426
427
428
429
430
431
432
433
434
435
436
437
        """Aborts a sequence group with the given ID.

        Check if the sequence group with the given ID
            is present in any of the state queue.
        If present, remove the sequence group from the state queue.
            Also, if any of the sequences in the sequence group is not finished,
                free the sequence with status `FINISHED_ABORTED`.
        Otherwise, do nothing.

        Args:
            request_id: The ID(s) of the sequence group to abort.
        """
Antoni Baum's avatar
Antoni Baum committed
438
439
440
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
441
        for state_queue in [self.waiting, self.running, self.swapped]:
ljss's avatar
ljss committed
442
            aborted_groups: List[SequenceGroup] = []
443
444
445
            for seq_group in state_queue:
                if not request_ids:
                    # Using 'break' here may add two extra iterations,
446
                    # but is acceptable to reduce complexity.
447
                    break
Antoni Baum's avatar
Antoni Baum committed
448
                if seq_group.request_id in request_ids:
449
450
                    # Appending aborted group into pending list.
                    aborted_groups.append(seq_group)
Antoni Baum's avatar
Antoni Baum committed
451
                    request_ids.remove(seq_group.request_id)
452
453
454
            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
                state_queue.remove(aborted_group)
455
                # Remove the aborted request from the Mamba cache.
456
                self._finished_requests_ids.append(aborted_group.request_id)
ljss's avatar
ljss committed
457
                for seq in aborted_group.get_seqs():
458
459
460
461
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)
462

463
464
465
466
467
468
469
470
471
472
473
474
475
                self._free_seq_group_cross_attn_blocks(aborted_group)

    def _free_seq_group_cross_attn_blocks(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """
        Release the cross-attention block table of a sequence group.
        Does nothing unless the group reports itself as encoder/decoder.
        """
        if not seq_group.is_encoder_decoder():
            return
        self.block_manager.free_cross(seq_group)

476
    def has_unfinished_seqs(self) -> bool:
477
478
        return len(self.waiting) != 0 or len(self.running) != 0 or len(
            self.swapped) != 0
479

480
481
482
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        """Return the block manager's prefix cache hit rate for `device`."""
        hit_rate = self.block_manager.get_prefix_cache_hit_rate(device)
        return hit_rate

483
484
485
    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

Mor Zusman's avatar
Mor Zusman committed
486
487
488
489
490
491
    def get_and_reset_finished_requests_ids(self) -> List[str]:
        """Flushes the list of request ids of previously finished seq_groups."""
        finished_requests_ids = self._finished_requests_ids
        self._finished_requests_ids = list()
        return finished_requests_ids

492
    def _schedule_running(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerRunningOutputs:
        """Schedule sequence groups that are running.

        Running queue should include decode and chunked prefill requests.

        Groups are taken from the front of `self.running`. While slots
        cannot be appended for the current group, victims are taken from the
        back of the queue and preempted; when the queue is empty the current
        group itself is the victim.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any decodes are preempted.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any decodes are preempted.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled  if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerRunningOutputs.
        """
        # The result object comes from a per-iteration object cache, so its
        # containers are cleared before they are filled again.
        ret: SchedulerRunningOutputs = \
            self._scheduler_running_outputs_cache[self.cache_id].get_object()
        ret.blocks_to_swap_out.clear()
        ret.blocks_to_copy.clear()
        ret.decode_seq_groups.clear()
        ret.prefill_seq_groups.clear()
        ret.preempted.clear()
        ret.swapped_out.clear()

        ret.num_lookahead_slots = self._get_num_lookahead_slots(
            is_prefill=False, enable_chunking=enable_chunking)

        ret.decode_seq_groups_list.clear()
        ret.prefill_seq_groups_list.clear()

        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
        blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy

        decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
        prefill_seq_groups: List[
            ScheduledSequenceGroup] = ret.prefill_seq_groups
        preempted: List[SequenceGroup] = ret.preempted
        swapped_out: List[SequenceGroup] = ret.swapped_out

        running_queue = self.running
        assert len(self._async_stopped) == 0
        while running_queue:
            seq_group = running_queue[0]
            num_running_tokens = self._get_num_new_tokens(
                seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

            if num_running_tokens == 0:
                # No budget => Stop
                break

            running_queue.popleft()

            # With async postprocessor, an extra decode run is done
            # to process the final tokens. The check below avoids this extra
            # decode run when the model max len is reached, in order to avoid
            # a memory overflow.
            if self.use_async_output_proc and seq_group.seqs[0].get_len(
            ) > self.scheduler_config.max_model_len:
                self._async_stopped.append(seq_group)
                continue

            # NOTE(woosuk): Preemption happens only when there is no available
            # slot to keep all the sequence groups in the RUNNING state.
            while not self._can_append_slots(seq_group, enable_chunking):
                budget.subtract_num_batched_tokens(seq_group.request_id,
                                                   num_running_tokens)
                num_running_seqs = seq_group.get_max_num_running_seqs()
                budget.subtract_num_seqs(seq_group.request_id,
                                         num_running_seqs)

                if (curr_loras is not None and seq_group.lora_int_id > 0
                        and seq_group.lora_int_id in curr_loras):
                    curr_loras.remove(seq_group.lora_int_id)

                # Determine victim sequence
                cont_loop = True
                if running_queue:
                    # Preempt the lowest-priority sequence group.
                    victim_seq_group = running_queue.pop()
                else:
                    # No other sequence group can be preempted.
                    # Preempt the current sequence group.
                    # Note: This is also where we stop this loop
                    # (since there is nothing else to preempt)
                    victim_seq_group = seq_group
                    cont_loop = False

                # With async postprocessor, before preempting a sequence
                # we need to ensure it has no pending async postprocessor
                do_preempt = True
                if self.use_async_output_proc:
                    assert self.output_proc_callback is not None
                    self.output_proc_callback(
                        request_id=victim_seq_group.request_id)

                    # It may be that the async pending "victim_seq_group"
                    # becomes finished, in which case we simply free it.
                    if victim_seq_group.is_finished():
                        self._free_finished_seq_group(victim_seq_group)
                        do_preempt = False

                # Do preemption
                if do_preempt:
                    preempted_mode = self._preempt(victim_seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(victim_seq_group)
                    else:
                        swapped_out.append(victim_seq_group)

                if not cont_loop:
                    break
            else:
                # while/else: this branch runs only when the loop above ended
                # without `break`, i.e. slots can be appended for seq_group.
                self._append_slots(seq_group, blocks_to_copy, enable_chunking)
                is_prefill = seq_group.is_prefill()

                scheduled_seq_group: ScheduledSequenceGroup = \
                    self._scheduled_seq_group_cache[self.cache_id].get_object()
                scheduled_seq_group.seq_group = seq_group
                if is_prefill:
                    scheduled_seq_group.token_chunk_size = num_running_tokens
                    prefill_seq_groups.append(scheduled_seq_group)
                    ret.prefill_seq_groups_list.append(seq_group)
                else:
                    scheduled_seq_group.token_chunk_size = 1
                    decode_seq_groups.append(scheduled_seq_group)
                    ret.decode_seq_groups_list.append(seq_group)

                budget.add_num_batched_tokens(seq_group.request_id,
                                              num_running_tokens)
                # OPTIMIZATION:  Note that get_max_num_running_seqs is
                # expensive. For the default scheduling chase where
                # enable_chunking is False, num_seqs are updated before running
                # this method, so we don't have to update it again here.
                if enable_chunking:
                    num_running_seqs = seq_group.get_max_num_running_seqs()
                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)
                if curr_loras is not None and seq_group.lora_int_id > 0:
                    curr_loras.add(seq_group.lora_int_id)

        # Reset the caches at `next_cache_id` (the same caches when only one
        # cache set is kept).
        self._scheduler_running_outputs_cache[self.next_cache_id].reset()
        self._scheduled_seq_group_cache[self.next_cache_id].reset()

        return ret
645

646
647
648
649
    def _schedule_swapped(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
650
        enable_chunking: bool = False,
651
    ) -> SchedulerSwappedInOutputs:
652
        """Schedule sequence groups that are swapped out.
653

654
655
656
        It schedules swapped requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.
657

658
659
660
661
662
        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are swapped in.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are swapped in.
663
664
665
666
667
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled  if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

668
669
670
671
        Returns:
            SchedulerSwappedInOutputs.
        """
        # Blocks that need to be swapped or copied before model execution.
672
        blocks_to_swap_in: List[Tuple[int, int]] = []
673
        blocks_to_copy: List[Tuple[int, int]] = []
674
675
        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []
676
        infeasible_seq_groups: List[SequenceGroup] = []
677

678
679
        swapped_queue = self.swapped

680
        leftover_swapped: Deque[SequenceGroup] = deque()
681
682
683
684
        while swapped_queue:
            seq_group = swapped_queue[0]

            # If the sequence group cannot be swapped in, stop.
685
686
            is_prefill = seq_group.is_prefill()
            alloc_status = self.block_manager.can_swap_in(
687
688
                seq_group,
                self._get_num_lookahead_slots(is_prefill, enable_chunking))
689
            if alloc_status == AllocStatus.LATER:
690
                break
691
692
693
694
695
696
697
698
699
700
            elif alloc_status == AllocStatus.NEVER:
                logger.warning(
                    "Failing the request %s because there's not enough kv "
                    "cache blocks to run the entire sequence.",
                    seq_group.request_id)
                for seq in seq_group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible_seq_groups.append(seq_group)
                swapped_queue.popleft()
                continue
701
702
703
704

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
705
706
707
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
708
709
710
711
712
713
714
715
716
717
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_swapped.appendleft(seq_group)
                    swapped_queue.popleft()
                    continue

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
718
719
720
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.SWAPPED,
                                                      enable_chunking, budget)
721

722
723
724
            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
725
726
727
728
729
730
                break

            if lora_int_id > 0 and curr_loras is not None:
                curr_loras.add(lora_int_id)
            swapped_queue.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
731
            self._append_slots(seq_group, blocks_to_copy, enable_chunking)
732
733
734
735
736
737
738
739
740
741
            is_prefill = seq_group.is_prefill()
            if is_prefill:
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(seq_group,
                                           token_chunk_size=num_new_tokens))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)
742
743
744

        swapped_queue.extendleft(leftover_swapped)

745
        return SchedulerSwappedInOutputs(
746
747
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
748
749
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_copy=blocks_to_copy,
750
            num_lookahead_slots=self._get_num_lookahead_slots(
751
                is_prefill=False, enable_chunking=enable_chunking),
752
753
            infeasible_seq_groups=infeasible_seq_groups,
        )
754

755
    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
        """Return the maximum prompt length accepted for `seq_group`.

        With chunked prefill enabled (and multi-step disabled) the limit is
        `max_model_len`. Otherwise it is additionally capped by
        `max_num_batched_tokens`. A LoRA request that carries a
        `long_lora_max_len` overrides the limit with that value.

        Args:
            seq_group: The sequence group whose prompt is being checked.

        Returns:
            The prompt-length limit, in tokens.
        """
        config = self.scheduler_config
        prompt_limit = config.max_model_len
        chunked_single_step = (config.chunked_prefill_enabled
                               and not config.is_multi_step)
        if not chunked_single_step:
            # The prompt cannot be split across steps here, so it must also
            # fit in the per-step token budget.
            prompt_limit = min(prompt_limit, config.max_num_batched_tokens)

        lora_request = seq_group.lora_request
        long_lora_max_len = (lora_request.long_lora_max_len
                             if lora_request else None)
        if not long_lora_max_len:
            return prompt_limit

        # Model is fine tuned with long context. Return the fine tuned max_len.
        assert prompt_limit <= long_lora_max_len
        return long_lora_max_len

771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
    def _get_priority(self,
                      seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
        """Build the key used to order sequence groups by priority.

        The user-defined priority is compared first; the arrival time only
        breaks ties between equal priorities.

        Args:
            seq_group: The sequence group to rank.

        Returns:
            A `(priority, arrival_time)` tuple.
        """
        user_priority = seq_group.priority
        arrived_at = seq_group.arrival_time
        return (user_priority, arrived_at)

    def _schedule_priority_preemption(
        self,
        budget: SchedulingBudget,
    ) -> int:
        """Sort both queues by priority and force-preempt where needed.

        The head of the waiting queue is compared against the running group
        that sorts last under `_get_priority`. While that running group
        sorts after the waiting one and the waiting one cannot be admitted
        (no new tokens, no free blocks, or no budget), the running group is
        preempted in recompute mode and moved to the waiting queue.
        Priority-based preemption is used with the priority policy.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any running requests are preempted.

        Returns:
            A count of priority-based preemptions.
        """
        pending = self.waiting
        ordered_running = deque(sorted(self.running, key=self._get_priority))

        # Handed to `_preempt`; this method does not read it afterwards.
        blocks_to_swap_out: List[Tuple[int, int]] = []
        num_forced = 0

        if pending:
            candidate = pending.popleft()
            needed_seqs = candidate.get_max_num_running_seqs()
            needed_tokens = self._get_num_new_tokens(candidate,
                                                     SequenceStatus.WAITING,
                                                     False, budget)

            while ordered_running:
                # Only preempt while a priority inversion exists.
                if not (self._get_priority(ordered_running[-1]) >
                        self._get_priority(candidate)):
                    break

                # Stop as soon as the waiting group can be admitted.
                alloc_status = self.block_manager.can_allocate(candidate)
                if (needed_tokens and alloc_status == AllocStatus.OK
                        and budget.can_schedule(num_new_tokens=needed_tokens,
                                                num_new_seqs=needed_seqs)):
                    break

                # Give the victim's tokens and sequences back to the budget.
                victim = ordered_running.pop()
                victim_tokens = self._get_num_new_tokens(
                    victim, SequenceStatus.RUNNING, False, budget)
                budget.subtract_num_batched_tokens(victim.request_id,
                                                   victim_tokens)
                victim_seqs = victim.get_max_num_running_seqs()
                budget.subtract_num_seqs(victim.request_id, victim_seqs)

                # Evict the victim and queue it for rescheduling.
                self._preempt(victim, blocks_to_swap_out,
                              PreemptionMode.RECOMPUTE)
                pending.appendleft(victim)
                num_forced += 1

            # The candidate itself is only inspected here, never scheduled.
            pending.appendleft(candidate)

        self.waiting = deque(sorted(pending, key=self._get_priority))
        self.running = ordered_running
        return num_forced

844
845
846
847
    def _schedule_prefills(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerPrefillOutputs:
        """Admit waiting sequence groups into the batch as prefills.

        Groups are taken from the front of `self.waiting` for as long as
        `_passed_delay` allows, the block manager can allocate them and they
        fit `budget` and the LoRA limit (`max_loras`). The current scheduler
        treats PREEMPTED_FOR_RECOMPUTE as a new prefill that starts from the
        beginning and runs through the most recently generated tokens.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are scheduled.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are scheduled.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to
                schedule all tokens.

        Returns:
            SchedulerPrefillOutputs.
        """
        scheduled: List[ScheduledSequenceGroup] = []
        ignored: List[SequenceGroup] = []
        # Groups skipped only because no LoRA slot is free right now.
        lora_deferred: Deque[SequenceGroup] = deque()
        pending = self.waiting
        multi_step_chunking = bool(self.scheduler_config.is_multi_step
                                   and enable_chunking)

        def ignore_head(group, seqs):
            # Mark the group at the head of the queue as ignored and drop it.
            for seq in seqs:
                seq.status = SequenceStatus.FINISHED_IGNORED
            ignored.append(group)
            pending.popleft()

        while self._passed_delay(time.time()) and pending:
            head = pending[0]

            prompt_seqs = head.get_seqs(status=SequenceStatus.WAITING)
            assert len(prompt_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            chunk_size = self._get_num_new_tokens(head,
                                                  SequenceStatus.WAITING,
                                                  enable_chunking, budget)
            if not enable_chunking:
                # Without chunking the whole prompt goes in one step.
                full_len = prompt_seqs[0].get_len()
                assert chunk_size == full_len

            limit = self._get_prompt_limit(head)
            if chunk_size > limit:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds limit of %d", chunk_size, limit)
                ignore_head(head, prompt_seqs)
                continue

            lookahead = 0
            if multi_step_chunking:
                lookahead = self._get_num_lookahead_slots(
                    True, enable_chunking)

            # LATER stops the scan; NEVER rejects the request outright.
            alloc_status = self.block_manager.can_allocate(
                head, num_lookahead_slots=lookahead)
            if alloc_status == AllocStatus.LATER:
                break
            if alloc_status == AllocStatus.NEVER:
                logger.warning(
                    "Input prompt (%d tokens) + lookahead slots (%d) is "
                    "too long and exceeds the capacity of block_manager",
                    chunk_size, lookahead)
                ignore_head(head, prompt_seqs)
                continue

            lora_id = 0
            if self.lora_enabled:
                lora_id = head.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                needs_new_slot = lora_id > 0 and lora_id not in curr_loras
                if (needs_new_slot
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    lora_deferred.appendleft(head)
                    pending.popleft()
                    continue

            num_seqs = head.get_max_num_running_seqs()
            if chunk_size == 0 or not budget.can_schedule(
                    num_new_tokens=chunk_size, num_new_seqs=num_seqs):
                break

            # Can schedule this request.
            if curr_loras is not None and lora_id > 0:
                curr_loras.add(lora_id)
            pending.popleft()
            self._allocate_and_set_running(head)

            if multi_step_chunking:
                copies: List[Tuple[int, int]] = []
                # init_multi_step_from_lookahead_slots happens in append_slots
                self._append_slots(head, copies, enable_chunking)
                # This assert will trip when a copy-on-write happens. This is
                # not a concern as the very first sequence-group block
                # allocation happens above. Still, we have the assert to
                # catch any edge-cases.
                assert not copies
            else:
                head.init_multi_step_from_lookahead_slots(
                    lookahead,
                    num_scheduler_steps=self.scheduler_config.
                    num_scheduler_steps,
                    is_multi_step=self.scheduler_config.is_multi_step,
                    enable_chunking=enable_chunking)

            scheduled.append(
                ScheduledSequenceGroup(seq_group=head,
                                       token_chunk_size=chunk_size))
            budget.add_num_batched_tokens(head.request_id, chunk_size)
            budget.add_num_seqs(head.request_id, num_seqs)

        # Re-insert the LoRA-deferred groups at the front of the queue.
        pending.extendleft(lora_deferred)
        if scheduled:
            self.prev_prompt = True

        return SchedulerPrefillOutputs(
            seq_groups=scheduled,
            ignored_seq_groups=ignored,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=True, enable_chunking=enable_chunking))
984

985
986
    def _schedule_default(self) -> SchedulerOutputs:
        """Run one scheduling step with the default (non-chunked) policy.

        The policy is designed to optimize throughput. It first batches as
        many prefill requests as possible; only when no prefill was
        scheduled does it schedule decodes and, if nothing was preempted,
        swap requests back in. If there's a pressure on GPU memory, decode
        requests can be swapped or preempted.

        Returns:
            SchedulerOutputs describing the batch for this step.
        """
        config = self.scheduler_config
        budget = SchedulingBudget(
            token_budget=config.max_num_batched_tokens,
            max_num_seqs=config.max_num_seqs,
        )
        # Count the running sequences up front so that prefill scheduling
        # cannot push the batch beyond max_num_seqs.
        for running_group in self.running:
            budget.add_num_seqs(running_group.request_id,
                                running_group.get_max_num_running_seqs())

        curr_loras: Optional[Set[int]] = None
        if self.lora_enabled:
            curr_loras = {
                running_group.lora_int_id
                for running_group in self.running
                if running_group.lora_int_id > 0
            }

        prefills = SchedulerPrefillOutputs.create_empty()
        running_scheduled = SchedulerRunningOutputs.create_empty()
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Swapped requests take precedence: new prefills are only admitted
        # when nothing is swapped out.
        if not self.swapped:
            prefills = self._schedule_prefills(budget,
                                               curr_loras,
                                               enable_chunking=False)

        has_prefills = len(prefills.seq_groups) > 0
        if not has_prefills and config.policy == "priority":
            self._schedule_priority_preemption(budget)

        # Don't schedule decodes if prefills are scheduled.
        # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
        # only contains decode requests, not chunked prefills.
        if not has_prefills:
            running_scheduled = self._schedule_running(budget,
                                                       curr_loras,
                                                       enable_chunking=False)

            # A preemption means there is no slot for new running requests,
            # so swapping in is skipped in that case.
            num_evicted = (len(running_scheduled.preempted) +
                           len(running_scheduled.swapped_out))
            if num_evicted == 0:
                swapped_in = self._schedule_swapped(budget, curr_loras)

        assert budget.num_batched_tokens <= config.max_num_batched_tokens
        assert budget.num_curr_seqs <= config.max_num_seqs

        # Update waiting requests.
        self.waiting.extendleft(running_scheduled.preempted)

        # Update new running requests.
        if len(prefills.seq_groups) > 0:
            self.running.extend([s.seq_group for s in prefills.seq_groups])
        self.running.extend(running_scheduled.decode_seq_groups_list)
        if len(swapped_in.decode_seq_groups) > 0:
            self.running.extend(
                [s.seq_group for s in swapped_in.decode_seq_groups])

        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)
        preempted = (len(running_scheduled.preempted) +
                     len(running_scheduled.swapped_out))

        # There should be no prefill from running queue because this policy
        # doesn't allow chunked prefills.
        assert len(running_scheduled.prefill_seq_groups) == 0
        assert len(swapped_in.prefill_seq_groups) == 0

        # Merge the per-stage lists in place, prefills first.
        num_prefill_groups = len(prefills.seq_groups)
        if num_prefill_groups > 0:
            scheduled_seq_groups = prefills.seq_groups
            scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
        else:
            scheduled_seq_groups = running_scheduled.decode_seq_groups
        scheduled_seq_groups.extend(swapped_in.decode_seq_groups)

        blocks_to_copy = running_scheduled.blocks_to_copy
        blocks_to_copy.extend(swapped_in.blocks_to_copy)

        ignored_seq_groups = prefills.ignored_seq_groups
        ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)

        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_seq_groups,
            num_prefill_groups=num_prefill_groups,
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=preempted,
        )

1089
    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
        """Run one scheduling step with the chunked-prefill policy.

        Chunked prefill allows prefill requests to be chunked and batched
        together with decode requests. The order in this method is:
        1. schedule the running queue via `_schedule_running`,
        2. swap requests back in, unless step 1 preempted anything,
        3. schedule new prefill requests.

        The policy can sustain high GPU utilization because prefill and
        decode requests share a batch, and it improves inter token latency
        because decode requests don't need to be blocked by prefill
        requests.

        Returns:
            SchedulerOutputs describing the batch for this step.
        """
        config = self.scheduler_config
        budget = SchedulingBudget(
            token_budget=config.max_num_batched_tokens,
            max_num_seqs=config.max_num_seqs,
        )
        curr_loras: Set[int] = set()

        prefills = SchedulerPrefillOutputs.create_empty()
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Decoding should be always scheduled first by fcfs.
        running_scheduled = self._schedule_running(budget,
                                                   curr_loras,
                                                   enable_chunking=True)

        # If preemption happens, it means we don't have space for swap-in.
        num_evicted = (len(running_scheduled.preempted) +
                       len(running_scheduled.swapped_out))
        if num_evicted == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)

        # Schedule new prefills.
        prefills = self._schedule_prefills(budget,
                                           curr_loras,
                                           enable_chunking=True)

        assert budget.num_batched_tokens <= config.max_num_batched_tokens
        assert budget.num_curr_seqs <= config.max_num_seqs

        # Update waiting requests.
        self.waiting.extendleft(running_scheduled.preempted)

        # Update new running requests.
        # By default, vLLM scheduler prioritizes prefills.
        # Once chunked prefill is enabled,
        # the policy is changed to prioritize decode requests.
        for stage_groups in (
                swapped_in.decode_seq_groups,
                swapped_in.prefill_seq_groups,
                running_scheduled.decode_seq_groups,
                running_scheduled.prefill_seq_groups,
                prefills.seq_groups,
        ):
            self.running.extend([s.seq_group for s in stage_groups])

        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)

        # Prefill groups come before decode groups in the scheduled batch.
        prefill_part = (prefills.seq_groups +
                        running_scheduled.prefill_seq_groups +
                        swapped_in.prefill_seq_groups)
        decode_part = (running_scheduled.decode_seq_groups +
                       swapped_in.decode_seq_groups)
        num_prefill_groups = (len(prefills.seq_groups) +
                              len(swapped_in.prefill_seq_groups) +
                              len(running_scheduled.prefill_seq_groups))
        preempted = (len(running_scheduled.preempted) +
                     len(running_scheduled.swapped_out))

        return SchedulerOutputs(
            scheduled_seq_groups=prefill_part + decode_part,
            num_prefill_groups=num_prefill_groups,
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=(running_scheduled.blocks_to_copy +
                            swapped_in.blocks_to_copy),
            ignored_seq_groups=(prefills.ignored_seq_groups +
                                swapped_in.infeasible_seq_groups),
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
            running_queue_size=len(self.running),
            preempted=preempted,
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
1172

1173
1174
1175
1176
1177
1178
1179
    def _schedule(self) -> SchedulerOutputs:
        """Dispatch to the scheduling policy selected by the config.

        Returns:
            The result of `_schedule_chunked_prefill` when
            `scheduler_config.chunked_prefill_enabled` is set, otherwise the
            result of `_schedule_default`.
        """
        if not self.scheduler_config.chunked_prefill_enabled:
            return self._schedule_default()
        return self._schedule_chunked_prefill()

1180
1181
    def _can_append_slots(self, seq_group: SequenceGroup,
                          enable_chunking: bool) -> bool:
        """Check whether the KV cache has room to continue `seq_group`.

        Args:
            seq_group: The sequence group that wants to keep generating.
            enable_chunking: Forwarded to `_get_num_lookahead_slots` when
                computing how many lookahead slots are required.

        Returns:
            The block manager's answer, or False when an artificial
            preemption is triggered (testing only).
        """
        # Testing hook: randomly report "no space" to force preemptions,
        # for as long as the artificial preemption counter is positive.
        if self.enable_artificial_preemption:
            roll = random.uniform(0, 1)
            if (roll < ARTIFICIAL_PREEMPTION_PROB
                    and self.artificial_preempt_cnt > 0):
                self.artificial_preempt_cnt -= 1
                return False

        group_is_prefill = seq_group.is_prefill()
        lookahead = self._get_num_lookahead_slots(group_is_prefill,
                                                  enable_chunking)

        if group_is_prefill and lookahead > 0:
            # Appending prefill slots only happens multi-step and
            # chunked-prefill are enabled together.
            assert self.scheduler_config.is_multi_step and enable_chunking

        return self.block_manager.can_append_slots(
            seq_group=seq_group, num_lookahead_slots=lookahead)
1203

1204
    def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
        """Tell whether async output processing is allowed for `seq_group`.

        async_output_proc is allowed only when we have a single sequence
        in the sequence group: either the group has no sampling params, or
        its sampling params request exactly one sequence (`n == 1`).

        Args:
            seq_group: The sequence group to check.

        Returns:
            True if async output processing may be used for this group.
        """
        params = seq_group.sampling_params
        if params is None:
            return True
        return params.n == 1
1210
1211
1212
1213

    def schedule(
            self
    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
1214
1215
1216
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
1217
        scheduler_start_time = time.perf_counter()
1218

1219
        scheduler_outputs: SchedulerOutputs = self._schedule()
1220
        now = time.time()
1221

1222
1223
1224
        if not self.cache_config.enable_prefix_caching:
            common_computed_block_nums = []

1225
        allow_async_output_proc: bool = self.use_async_output_proc
1226

1227
        # Create input data structures.
1228
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
1229
1230
        for i, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
1231
1232
            seq_group = scheduled_seq_group.seq_group
            token_chunk_size = scheduled_seq_group.token_chunk_size
1233
1234
            seq_group.maybe_set_first_scheduled_time(now)

1235
1236
1237
1238
1239
            seq_group_metadata = self._seq_group_metadata_cache[
                self.cache_id].get_object()
            seq_group_metadata.seq_data.clear()
            seq_group_metadata.block_tables.clear()

1240
            # seq_id -> SequenceData
1241
            seq_data: Dict[int, SequenceData] = {}
1242
            # seq_id -> physical block numbers
1243
            block_tables: Dict[int, List[int]] = {}
1244

1245
1246
            if seq_group.is_encoder_decoder():
                # Encoder associated with SequenceGroup
1247
1248
1249
                encoder_seq = seq_group.get_encoder_seq()
                assert encoder_seq is not None
                encoder_seq_data = encoder_seq.data
1250
1251
1252
1253
1254
1255
1256
1257
                # Block table for cross-attention
                # Also managed at SequenceGroup level
                cross_block_table = self.block_manager.get_cross_block_table(
                    seq_group)
            else:
                encoder_seq_data = None
                cross_block_table = None

1258
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
1259
                seq_id = seq.seq_id
1260
                seq_data[seq_id] = seq.data
1261
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
1262
                self.block_manager.access_all_blocks_in_seq(seq, now)
1263

1264
1265
1266
1267
            if self.cache_config.enable_prefix_caching:
                common_computed_block_nums = (
                    self.block_manager.get_common_computed_block_ids(
                        seq_group.get_seqs(status=SequenceStatus.RUNNING)))
1268

1269
            do_sample = True
1270
1271
1272
1273
1274
            is_prompt = seq_group.is_prefill()
            # We should send the metadata to workers when the first prefill
            # is sent. Subsequent requests could be chunked prefill or decode.
            is_first_prefill = False
            if is_prompt:
1275
1276
1277
                seqs = seq_group.get_seqs()
                # Prefill has only 1 sequence.
                assert len(seqs) == 1
1278
1279
                num_computed_tokens = seqs[0].data.get_num_computed_tokens()
                is_first_prefill = num_computed_tokens == 0
1280
1281
1282
1283
1284
                # In the next iteration, all prompt tokens are not computed.
                # It means the prefill is chunked, and we don't need sampling.
                # NOTE: We use get_len instead of get_prompt_len because when
                # a sequence is preempted, prefill includes previous generated
                # output tokens.
1285
                if (token_chunk_size + num_computed_tokens <
1286
1287
1288
                        seqs[0].data.get_len()):
                    do_sample = False

1289
1290
            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
            if is_first_prefill or not self.scheduler_config.send_delta_data:
                seq_group_metadata = SequenceGroupMetadata(
                    request_id=seq_group.request_id,
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=seq_group.sampling_params,
                    block_tables=block_tables,
                    do_sample=do_sample,
                    pooling_params=seq_group.pooling_params,
                    token_chunk_size=token_chunk_size,
                    lora_request=seq_group.lora_request,
                    computed_block_nums=common_computed_block_nums,
                    encoder_seq_data=encoder_seq_data,
                    cross_block_table=cross_block_table,
                    state=seq_group.state,
                    # `multi_modal_data` will only be present for the 1st comm
                    # between engine and worker.
                    # the subsequent comms can still use delta, but
                    # `multi_modal_data` will be None.
                    multi_modal_data=seq_group.multi_modal_data
                    if scheduler_outputs.num_prefill_groups > 0 else None,
1312
                    mm_processor_kwargs=seq_group.mm_processor_kwargs,
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
                    prompt_adapter_request=seq_group.prompt_adapter_request,
                )
            else:
                # When SPMD mode is enabled, we only send delta data except for
                # the first request to reduce serialization cost.
                seq_data_delta = {}
                for id, data in seq_data.items():
                    seq_data_delta[id] = data.get_delta_and_reset()
                seq_group_metadata = SequenceGroupMetadataDelta(
                    seq_data_delta,
                    seq_group.request_id,
                    block_tables,
                    is_prompt,
                    do_sample=do_sample,
                    token_chunk_size=token_chunk_size,
                    computed_block_nums=common_computed_block_nums,
                )
1330
            seq_group_metadata_list.append(seq_group_metadata)
1331

1332
1333
1334
1335
            if allow_async_output_proc:
                allow_async_output_proc = self._allow_async_output_proc(
                    seq_group)

1336
1337
1338
1339
        # Now that the batch has been created, we can assume all blocks in the
        # batch will have been computed before the next scheduling invocation.
        # This is because the engine assumes that a failure in model execution
        # will crash the vLLM instance / will not retry.
1340
1341
        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
            self.block_manager.mark_blocks_as_computed(
1342
1343
                scheduled_seq_group.seq_group,
                scheduled_seq_group.token_chunk_size)
1344

1345
1346
        self._seq_group_metadata_cache[self.next_cache_id].reset()

1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
        scheduler_time = time.perf_counter() - scheduler_start_time
        # Add this to scheduler time to all the sequences that are currently
        # running. This will help estimate if the scheduler is a significant
        # component in the e2e latency.
        for seq_group in self.running:
            if seq_group is not None and seq_group.metrics is not None:
                if seq_group.metrics.scheduler_time is not None:
                    seq_group.metrics.scheduler_time += scheduler_time
                else:
                    seq_group.metrics.scheduler_time = scheduler_time

1358
1359
1360
1361
1362
1363
        # Move to next cache (if exists)
        self.cache_id = self.next_cache_id

        # Return results
        return (seq_group_metadata_list, scheduler_outputs,
                allow_async_output_proc)
1364

1365
1366
    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Fork `child_seq` from `parent_seq`.

        The work is delegated entirely to `self.block_manager.fork`.

        Args:
            parent_seq: The sequence being forked from.
            child_seq: The sequence being forked into.
        """
        manager = self.block_manager
        manager.fork(parent_seq, child_seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
1367

1368
    def free_seq(self, seq: Sequence) -> None:
        """Release `seq` by handing it to `self.block_manager.free`.

        Args:
            seq: The sequence to free.
        """
        manager = self.block_manager
        manager.free(seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
1371

1372
1373
1374
1375
1376
1377
    def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
        """Free finished seqs in a sequence group."""
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                self.free_seq(seq)

1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
    def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
        if seq_group.is_finished():
            # Free cross-attention block table, if it exists
            self._free_seq_group_cross_attn_blocks(seq_group)

            # Add the finished requests to the finished requests list.
            # This list will be used to update the Mamba cache in the
            # next step.
            self._finished_requests_ids.append(seq_group.request_id)

        # Free finished seqs
        self._free_finished_seqs(seq_group)

1391
    def free_finished_seq_groups(self) -> None:
        """Clean up finished sequence groups and rebuild `self.running`.

        Every group currently in `self.running` is passed to
        `_free_finished_seq_group`; only groups that are still unfinished
        are kept, in their original order. Afterwards, any groups queued
        in `self._async_stopped` are cleaned up and that list is cleared.
        """
        still_running: Deque[SequenceGroup] = deque()
        for group in self.running:
            self._free_finished_seq_group(group)
            if group.is_finished():
                continue
            still_running.append(group)
        self.running = still_running

        # Handle async stopped sequence groups
        # (ones that reached max model len)
        if not self._async_stopped:
            return

        for stopped_group in self._async_stopped:
            self._free_seq_group_cross_attn_blocks(stopped_group)
            self._finished_requests_ids.append(stopped_group.request_id)

            # Free finished seqs
            self._free_finished_seqs(stopped_group)

        self._async_stopped.clear()

1412
    def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for `seq_group` and mark its waiting seqs running.

        The block manager allocation happens first; then every sequence of
        the group that is in `WAITING` status is switched to `RUNNING`.

        Args:
            seq_group: The sequence group to allocate and start.
        """
        self.block_manager.allocate(seq_group)
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        for waiting_seq in waiting_seqs:
            waiting_seq.status = SequenceStatus.RUNNING

1417
1418
1419
1420
    def _append_slots(self,
                      seq_group: SequenceGroup,
                      blocks_to_copy: List[Tuple[int, int]],
                      enable_chunking: bool = False) -> None:
        """Append new slots to the sequences of a sequence group.

        Args:
            seq_group (SequenceGroup): The group whose sequences receive
                the new slots.
            blocks_to_copy (List[Tuple[int, int]]): Output list of
                (source block index, destination block index) pairs. It is
                extended in place with the pairs returned by the block
                manager for the appended slots.
            enable_chunking (bool): True if chunked prefill is enabled.
        """
        is_prefill: bool = seq_group.is_prefill()
        num_lookahead_slots: int = self._get_num_lookahead_slots(
            is_prefill, enable_chunking)

        seq_group.init_multi_step_from_lookahead_slots(
            num_lookahead_slots,
            num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
            is_multi_step=self.scheduler_config.is_multi_step,
            enable_chunking=enable_chunking)

        # Normally only RUNNING sequences get slots. In multi-step
        # chunked-prefill any sequence type can have slots appended, so no
        # status filter is applied in that case.
        any_status = self.scheduler_config.is_multi_step and enable_chunking
        status_filter: Optional[SequenceStatus] = (
            None if any_status else SequenceStatus.RUNNING)

        for seq in seq_group.get_seqs(status=status_filter):
            copy_pairs = self.block_manager.append_slots(
                seq, num_lookahead_slots)
            if len(copy_pairs) == 0:
                continue
            blocks_to_copy.extend(copy_pairs)
1453
1454
1455
1456

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> PreemptionMode:
        """Preempt `seq_group` and report which preemption mode was used.

        Args:
            seq_group: The sequence group to preempt.
            blocks_to_swap_out: Output list that is extended with block
                mappings when the group is preempted by swapping.
            preemption_mode: Mode to use. When None (the default), the mode
                is derived from `self.user_specified_preemption_mode` and
                the number of running sequences in the group.

        Returns:
            The preemption mode that was applied.

        Raises:
            AssertionError: If the resolved mode is neither RECOMPUTE nor
                SWAP.
        """
        # NOTE: An explicitly passed `preemption_mode` is honored. Previously
        # the argument was unconditionally overwritten below, so a caller's
        # choice was silently ignored. Callers that omit it are unaffected.
        if preemption_mode is None:
            # If preemption mode is not specified, we determine the mode as
            # follows:
            # We use recomputation by default since it incurs lower overhead
            # than swapping. However, when the sequence group has multiple
            # sequences (e.g., beam search), recomputation is not currently
            # supported. In such a case, we use swapping instead.
            # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
            # As swapped sequences are prioritized over waiting sequences,
            # sequence groups with multiple sequences are implicitly
            # prioritized over sequence groups with a single sequence.
            # TODO(woosuk): Support recomputation for sequence groups with
            # multiple sequences. This may require a more sophisticated CUDA
            # kernel.
            if self.user_specified_preemption_mode is None:
                if seq_group.get_max_num_running_seqs() == 1:
                    preemption_mode = PreemptionMode.RECOMPUTE
                else:
                    preemption_mode = PreemptionMode.SWAP
            elif self.user_specified_preemption_mode == "swap":
                preemption_mode = PreemptionMode.SWAP
            else:
                preemption_mode = PreemptionMode.RECOMPUTE

        # Only warn on every 50th preemption to avoid flooding the log.
        if self.num_cumulative_preemption % 50 == 0:
            logger.warning(
                "Sequence group %s is preempted by %s mode because there is "
                "not enough KV cache space. This can affect the end-to-end "
                "performance. Increase gpu_memory_utilization or "
                "tensor_parallel_size to provide more KV cache memory. "
                "total_num_cumulative_preemption=%d", seq_group.request_id,
                preemption_mode, self.num_cumulative_preemption + 1)
        self.num_cumulative_preemption += 1

        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")
        return preemption_mode
1499
1500
1501
1502
1503
1504
1505
1506
1507

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Preempt `seq_group` by discarding its state for later recompute.

        The group's RUNNING sequences (asserted to be exactly one) are set
        to WAITING, freed via `free_seq`, and have their state reset for
        recomputation.

        Args:
            seq_group: The sequence group to preempt.
        """
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(running_seqs) == 1
        for running_seq in running_seqs:
            running_seq.status = SequenceStatus.WAITING
            self.free_seq(running_seq)
            running_seq.reset_state_for_recompute()
1510
1511
1512
1513

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
1514
        blocks_to_swap_out: List[Tuple[int, int]],
1515
1516
1517
1518
1519
1520
    ) -> None:
        self._swap_out(seq_group, blocks_to_swap_out)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: List[Tuple[int, int]],
    ) -> None:
        """Swap `seq_group` back in and mark its swapped seqs as running.

        Args:
            seq_group: The sequence group to swap in.
            blocks_to_swap_in: Output list, extended in place with the
                mapping returned by `self.block_manager.swap_in`.
        """
        blocks_to_swap_in.extend(self.block_manager.swap_in(seq_group))
        swapped_seqs = seq_group.get_seqs(status=SequenceStatus.SWAPPED)
        for swapped_seq in swapped_seqs:
            swapped_seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: List[Tuple[int, int]],
    ) -> None:
        """Swap `seq_group` out and mark its running seqs as swapped.

        Args:
            seq_group: The sequence group to swap out.
            blocks_to_swap_out: Output list, extended in place with the
                mapping returned by `self.block_manager.swap_out`.

        Raises:
            RuntimeError: If the block manager reports that the group
                cannot be swapped out.
        """
        swappable = self.block_manager.can_swap_out(seq_group)
        if not swappable:
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")

        blocks_to_swap_out.extend(self.block_manager.swap_out(seq_group))
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        for running_seq in running_seqs:
            running_seq.status = SequenceStatus.SWAPPED
1543

1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
    def _passed_delay(self, now: float) -> bool:
        if self.prev_prompt:
            self.last_prompt_latency = now - self.prev_time
        self.prev_time, self.prev_prompt = now, False
        # Delay scheduling prompts to let waiting queue fill up
        if self.scheduler_config.delay_factor > 0 and self.waiting:
            earliest_arrival_time = min(
                [e.metrics.arrival_time for e in self.waiting])
            passed_delay = (
                (now - earliest_arrival_time) >
                (self.scheduler_config.delay_factor * self.last_prompt_latency)
                or not self.running)
        else:
            passed_delay = True
        return passed_delay
1559

1560
1561
    def _get_num_lookahead_slots(self, is_prefill: bool,
                                 enable_chunking: bool) -> int:
1562
1563
1564
1565
1566
1567
        """The number of slots to allocate per sequence per step, beyond known
        token ids. Speculative decoding uses these slots to store KV activations
        of tokens which may or may not be accepted.

        Speculative decoding does not yet support prefill, so we do not perform
        lookahead allocation for prefill.
1568
1569
1570
1571

        When chunking is enabled with multi-step, we allocate lookahead slots
        for the prefills for when the prefills turn into decodes in the first
        step.
1572
1573
        """
        if is_prefill:
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
            if self.scheduler_config.is_multi_step and enable_chunking:
                # num_lookahead_slots was introduced in the context of decodes,
                # in Speculative Decoding.
                # When the num_scheduler_steps is 8, say, then the
                # num_lookahead_slots is 7. Meaning, we are doing a 1-step of
                # decode anyways and we wish to do 7 more.
                #
                # "lookaheads" for prefills, is introduced in support for
                # Chunked-Prefill in Multi-Step.
                return self.scheduler_config.num_lookahead_slots + 1
            else:
                return 0
1586
1587

        return self.scheduler_config.num_lookahead_slots
1588
1589
1590

    def _get_num_new_tokens(self, seq_group: SequenceGroup,
                            status: SequenceStatus, enable_chunking: bool,
1591
                            budget: SchedulingBudget) -> int:
1592
1593
1594
1595
1596
1597
1598
        """Get the next new tokens to compute for a given sequence group
            that's in a given `status`.

        The API could chunk the number of tokens to compute based on `budget`
        if `enable_chunking` is True. If a sequence group has multiple
        sequences (e.g., running beam search), it means it is in decoding
        phase, so chunking doesn't happen.
1599
1600

        Returns 0 if the new token cannot be computed due to token budget.
1601
1602
1603
1604
1605
        """
        num_new_tokens = 0
        seqs = seq_group.get_seqs(status=status)
        for seq in seqs:
            num_new_tokens += seq.get_num_new_tokens()
1606
        assert num_new_tokens > 0
1607
1608
1609
        # Chunk if a running request cannot fit in the given budget.
        # If number of seq > 1, it means it is doing beam search
        # in a decode phase. Do not chunk.
1610
        if enable_chunking and len(seqs) == 1:
1611
            remaining_token_budget = budget.remaining_token_budget()
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
            if self.scheduler_config.is_multi_step:
                # The current multi-step + chunked prefill capability does
                # not actually support chunking prompts.
                #
                # Therefore, `num_new_tokens` is computed in the same fashion
                # for both multi-step+chunked-prefill &
                # multi-step+chunked-prefill+APC
                #
                # Prompts with more tokens than the current remaining budget
                # are postponed to future scheduler steps
                if num_new_tokens > self._get_prompt_limit(seq_group):
                    # If the seq_group is in prompt-stage, pass the
                    # num_new_tokens as-is so the caller can ignore
                    # the sequence.
                    pass
                else:
                    num_new_tokens = 0 \
                        if num_new_tokens > remaining_token_budget \
                        else num_new_tokens
            elif self.cache_config.enable_prefix_caching:
1632
                # When prefix caching is enabled, we always allocate
1633
1634
                # the number of new tokens that is dividable by the block
                # size to avoid partial block matching.
1635
                block_size = self.cache_config.block_size
1636
1637
                remainder = budget.token_budget % block_size
                if remainder != 0:
1638
1639
1640
1641
1642
                    raise ValueError("When enabling chunked prefill and "
                                     "prefix caching, max_num_batched_tokens "
                                     "(chunk size) must be dividable by "
                                     "block size, but got chunk_size "
                                     f"({budget.token_budget}) % block_size "
1643
                                     f"({block_size}) = {remainder}")
1644
1645
1646
1647
1648
                if remaining_token_budget < num_new_tokens:
                    num_new_tokens = (remaining_token_budget //
                                      block_size) * block_size
            else:
                num_new_tokens = min(num_new_tokens, remaining_token_budget)
1649
        return num_new_tokens