scheduler.py 46.9 KB
Newer Older
1
2
import enum
import time
3
from collections import deque
4
from dataclasses import dataclass, field
5
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
6

7
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
8
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
9
from vllm.core.policy import Policy, PolicyFactory
Woosuk Kwon's avatar
Woosuk Kwon committed
10
from vllm.logger import init_logger
11
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
13
                           SequenceGroupMetadata, SequenceStatus)
14
from vllm.utils import merge_dicts
Woosuk Kwon's avatar
Woosuk Kwon committed
15

Woosuk Kwon's avatar
Woosuk Kwon committed
16
logger = init_logger(__name__)
17

Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
20
21
22
23
24
25
26
27
28
29
30
31
class PreemptionMode(enum.Enum):
    """Strategy used to evict a running sequence group when memory runs out.

    SWAP:
        The cache blocks of the preempted sequences are moved to CPU memory
        and moved back when the sequences are resumed.
    RECOMPUTE:
        The cache blocks of the preempted sequences are discarded. When the
        sequences are resumed they are treated as new prompts and their
        blocks are recomputed.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


32
33
@dataclass
class SchedulingBudget:
    """The available slots for scheduling.

    Tracks how many tokens and sequences have been committed to the batch
    being built, against `token_budget` and `max_num_seqs`.

    TODO(sang): Right now, the budget is request_id-aware meaning it can ignore
    budget update from the same request_id. It is because in normal scheduling
    path, we update RUNNING num_seqs ahead of time, meaning it could be
    updated more than once when scheduling RUNNING requests. Since this won't
    happen if we only have chunked prefill scheduling, we can remove this
    feature from the API when chunked prefill is enabled by default.
    """
    # Maximum number of tokens that may be batched.
    token_budget: int
    # Maximum number of sequences that may be batched.
    max_num_seqs: int
    # Request ids (strings, matching the `req_id: str` parameters below)
    # whose tokens are currently counted in `_num_batched_tokens`.
    _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    # Request ids whose sequences are currently counted in `_num_curr_seqs`.
    _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int) -> bool:
        """Return True if the extra tokens and sequences fit in the budget.

        Both counts must be non-zero (checked with asserts).
        """
        assert num_new_tokens != 0
        assert num_new_seqs != 0
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

    def remaining_token_budget(self) -> int:
        """Tokens still available under `token_budget`."""
        return self.token_budget - self.num_batched_tokens

    def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
        """Charge tokens to the budget; ignored if `req_id` is already charged."""
        if req_id in self._requeset_ids_num_batched_tokens:
            return

        self._requeset_ids_num_batched_tokens.add(req_id)
        self._num_batched_tokens += num_batched_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int):
        """Refund tokens; only has an effect if `req_id` is currently charged."""
        if req_id in self._requeset_ids_num_batched_tokens:
            self._requeset_ids_num_batched_tokens.remove(req_id)
            self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Charge sequences; ignored if `req_id` is already charged."""
        if req_id in self._requeset_ids_num_curr_seqs:
            return

        self._requeset_ids_num_curr_seqs.add(req_id)
        self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Refund sequences; only has an effect if `req_id` is charged."""
        if req_id in self._requeset_ids_num_curr_seqs:
            self._requeset_ids_num_curr_seqs.remove(req_id)
            self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self) -> int:
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self) -> int:
        return self._num_curr_seqs

92

93
94
95
96
97
98
99
100
101
102
@dataclass
class ScheduledSequenceGroup:
    """A sequence group that was selected by the scheduler.

    Attributes:
        seq_group: The scheduled sequence group.
        token_chunk_size: Total number of tokens to process for this group
            in the next iteration. It is 1 for decoding and equals the
            number of prompt tokens for prefill, unless the prefill is
            chunked, in which case it can be smaller.
    """
    seq_group: SequenceGroup
    token_chunk_size: int


103
@dataclass
class SchedulerOutputs:
    """The scheduling decision made from a scheduler."""
    # Scheduled sequence groups.
    scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
    # Number of prefill groups scheduled.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # Blocks to swap in. Dict of CPU -> GPU block number.
    blocks_to_swap_in: Dict[int, int]
    # Blocks to swap out. Dict of GPU -> CPU block number.
    blocks_to_swap_out: Dict[int, int]
    # Blocks to copy. Source to a list of dest blocks.
    blocks_to_copy: Dict[int, List[int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    def __post_init__(self):
        # Swap in and swap out should never happen at the same time.
        assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)

        # Number of distinct LoRA requests in the batch. Groups without a
        # LoRA request are excluded (see `lora_requests`), so a batch with
        # no LoRA keeps its original order instead of being re-sorted.
        self.num_loras: int = len(self.lora_requests)
        if self.num_loras > 0:
            self._sort_by_lora_ids()

    def is_empty(self) -> bool:
        """Return True if nothing is scheduled and no block ops are pending."""
        # NOTE: We do not consider the ignored sequence groups.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

    def _sort_by_lora_ids(self) -> None:
        """Reorder `scheduled_seq_groups` by (lora_int_id, request_id).

        When the batch mixes prefills and decodes, prefills are kept ahead of
        decodes so the prefill/decode partition described by
        `num_prefill_groups` is preserved.
        """
        groups = list(self.scheduled_seq_groups)
        # NOTE(review): assumes consumers read the first `num_prefill_groups`
        # entries as the prefills (as with chunked prefill) — confirm.
        keep_prefills_first = 0 < self.num_prefill_groups < len(groups)

        def sort_key(g: ScheduledSequenceGroup):
            key = (g.seq_group.lora_int_id, g.seq_group.request_id)
            if keep_prefills_first:
                return (not g.seq_group.is_prefill(), ) + key
            return key

        self.scheduled_seq_groups = sorted(groups, key=sort_key)

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """The distinct LoRA requests of the scheduled groups (None excluded)."""
        return {
            g.seq_group.lora_request
            for g in self.scheduled_seq_groups
            if g.seq_group.lora_request is not None
        }
144

145

146
@dataclass
class SchedulerRunningOutputs:
    """The requests that are scheduled from a running queue.

    Could contain prefill (prefill that's chunked) or decodes. If there's not
    enough memory, it can be preempted (for recompute) or swapped out.
    """
    # Selected sequences that are running and in a decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are running and in a prefill phase.
    # I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The preempted sequences.
    preempted: List[SequenceGroup]
    # Sequences that are swapped out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: Dict[int, int]
    # The blocks to copy.
    blocks_to_copy: Dict[int, List[int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an output with no groups, no block ops, no lookahead."""
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out={},
            blocks_to_copy={},
            num_lookahead_slots=0,
        )


@dataclass
class SchedulerSwappedInOutputs:
    """The requests that are scheduled from a swap queue.

    Could contain prefill (prefill that's chunked) or decodes.
    """
    # Selected sequences that are going to be swapped in and are in a
    # decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are going to be swapped in and are in a prefill
    # phase. I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: Dict[int, int]
    # The blocks to copy.
    blocks_to_copy: Dict[int, List[int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an output with no groups, no block ops, no lookahead."""
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            blocks_to_swap_in={},
            blocks_to_copy={},
            num_lookahead_slots=0,
        )


@dataclass
class SchedulerPrefillOutputs:
    """The requests that are scheduled from a waiting queue.

    Could contain fresh prefill requests or preempted requests that need
    to be recomputed from scratch.
    """
    # Selected sequences for prefill.
    seq_groups: List[ScheduledSequenceGroup]
    # Ignored sequence groups.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an output with no scheduled/ignored groups, no lookahead."""
        return cls(
            seq_groups=[],
            ignored_seq_groups=[],
            num_lookahead_slots=0,
        )


Woosuk Kwon's avatar
Woosuk Kwon committed
234
235
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
236
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        """Store the configs, build the block manager and empty queues.

        Args:
            scheduler_config: Scheduler limits and options (model length,
                batched-token limit, chunked prefill, block manager version).
            cache_config: Block layout used to construct the block manager.
            lora_config: LoRA settings, or None when LoRA is not used.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        # Longest accepted prompt. With chunked prefill only the model length
        # bounds it; otherwise it is also capped by max_num_batched_tokens.
        if not self.scheduler_config.chunked_prefill_enabled:
            self.prompt_limit = min(
                self.scheduler_config.max_model_len,
                self.scheduler_config.max_num_batched_tokens)
        else:
            self.prompt_limit = self.scheduler_config.max_model_len

        manager_version = ("v2" if self.scheduler_config.use_v2_block_manager
                           else "v1")
        manager_cls = BlockSpaceManager.get_block_space_manager_class(
            version=manager_version)

        # Create the block space manager.
        cache = self.cache_config
        self.block_manager = manager_cls(
            block_size=cache.block_size,
            num_gpu_blocks=cache.num_gpu_blocks,
            num_cpu_blocks=cache.num_cpu_blocks,
            sliding_window=cache.sliding_window,
            enable_caching=cache.enable_prefix_caching)

        # WAITING: new prefill requests and preempted requests.
        self.waiting: Deque[SequenceGroup] = deque()
        # RUNNING: decode requests (and chunked prefills when enabled).
        self.running: Deque[SequenceGroup] = deque()
        # SWAPPED: decode requests whose blocks were swapped out.
        self.swapped: Deque[SequenceGroup] = deque()

        # Time at the previous scheduling step.
        self.prev_time = 0.0
        # Whether a prompt was scheduled at the previous step.
        self.prev_prompt = False
        # Latency of the last prompt step.
        self.last_prompt_latency = 0.0

285
286
287
288
    @property
    def lora_enabled(self) -> bool:
        """True when a (truthy) LoRA config was given to the scheduler."""
        enabled = bool(self.lora_config)
        return enabled

289
290
291
292
293
    @property
    def num_decoding_tokens_per_seq(self) -> int:
        """Number of new tokens produced per sequence in a decode step.

        This scheduler always uses exactly one.
        """
        tokens_per_step = 1
        return tokens_per_step

294
    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        """Queue a new sequence group for scheduling.

        The group is appended to the WAITING queue, from which the prefill
        scheduling pass picks requests.
        """
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug("add_seq_group %s", seq_group.request_id)
        self.waiting.append(seq_group)
Woosuk Kwon's avatar
Woosuk Kwon committed
298

Antoni Baum's avatar
Antoni Baum committed
299
    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort the sequence group(s) with the given request ID(s).

        The WAITING, RUNNING and SWAPPED queues are searched in that order.
        Each matching group is removed from its queue. Every sequence in it
        that has not already finished is marked `FINISHED_ABORTED` and freed.
        IDs that match no group are silently ignored.

        Args:
            request_id: A single request ID or an iterable of request IDs.
        """
        if isinstance(request_id, str):
            pending_ids = set((request_id, ))
        else:
            pending_ids = set(request_id)

        for queue in (self.waiting, self.running, self.swapped):
            # Collect first, then remove, so the queue is not mutated while
            # it is being iterated.
            to_abort: List[SequenceGroup] = []
            for group in queue:
                if not pending_ids:
                    # Everything requested has been found already.
                    break
                if group.request_id in pending_ids:
                    to_abort.append(group)
                    pending_ids.remove(group.request_id)

            for group in to_abort:
                queue.remove(group)
                for seq in group.get_seqs():
                    if not seq.is_finished():
                        seq.status = SequenceStatus.FINISHED_ABORTED
                        self.free_seq(seq)
334

335
336
337
    def has_unfinished_seqs(self) -> bool:
        """Return True if the waiting, running or swapped queue is non-empty."""
        # `a or b or c` evaluates to one of the deques; coerce to a real bool
        # so the return value matches the annotation.
        return bool(self.waiting or self.running or self.swapped)

338
339
340
    def get_num_unfinished_seq_groups(self) -> int:
        """Count the sequence groups held in the waiting, running and
        swapped queues combined."""
        queues = (self.waiting, self.running, self.swapped)
        return sum(len(queue) for queue in queues)

341
    def _schedule_running(
        self,
        running_queue: deque,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        policy: Policy,
        enable_chunking: bool = False,
    ) -> Tuple[deque, SchedulerRunningOutputs]:
        """Schedule sequence groups that are running.

        Running queue should include decode and chunked prefill requests.
        Each group gets new slots if possible. Otherwise, groups are preempted
        from the back of the priority-sorted queue until it fits. If no other
        group is left, the group itself is preempted.

        Args:
            running_queue: The queue that contains running requests (i.e.,
                decodes). The given arguments are NOT in-place modified.
            budget: The scheduling budget. The argument is in-place updated
                when any decodes are preempted.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any decodes are preempted.
            policy: The sorting policy to sort running_queue.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` does not have enough capacity to
                schedule all tokens.

        Returns:
            A tuple of the remaining running queue (always empty after
            scheduling) and SchedulerRunningOutputs.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []
        preempted: List[SequenceGroup] = []
        swapped_out: List[SequenceGroup] = []

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        now = time.time()
        running_queue = policy.sort_by_priority(now, running_queue)

        while running_queue:
            seq_group = running_queue[0]
            num_running_tokens = self._get_num_new_tokens(
                seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

            # We can have up to 1 running prefill at any given time in running
            # queue, which means we can guarantee chunk size is at least 1.
            assert num_running_tokens != 0
            num_running_seqs = seq_group.get_max_num_running_seqs()

            running_queue.popleft()
            while not self._can_append_slots(seq_group):
                # Release what this group holds in the budget. The budget
                # ignores repeated subtractions for the same request id.
                budget.subtract_num_batched_tokens(seq_group.request_id,
                                                   num_running_tokens)
                budget.subtract_num_seqs(seq_group.request_id,
                                         num_running_seqs)
                # `curr_loras` is a set. set.pop() takes no argument, so the
                # id is removed with discard(), which is also safe when this
                # loop runs more than once for the same group.
                if curr_loras is not None and seq_group.lora_int_id > 0:
                    curr_loras.discard(seq_group.lora_int_id)

                if running_queue:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = running_queue.pop()
                    preempted_mode = self._preempt(victim_seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(victim_seq_group)
                    else:
                        swapped_out.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    preempted_mode = self._preempt(seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(seq_group)
                    else:
                        swapped_out.append(seq_group)
                    break
            else:
                # Reached only when the inner loop ends without `break`, i.e.
                # slots can now be appended for seq_group.
                logger.debug("append slot for %s", seq_group)
                self._append_slots(seq_group, blocks_to_copy)
                if seq_group.is_prefill():
                    prefill_seq_groups.append(
                        ScheduledSequenceGroup(
                            seq_group=seq_group,
                            token_chunk_size=num_running_tokens))
                else:
                    decode_seq_groups.append(
                        ScheduledSequenceGroup(seq_group=seq_group,
                                               token_chunk_size=1))
                budget.add_num_batched_tokens(seq_group.request_id,
                                              num_running_tokens)
                budget.add_num_seqs(seq_group.request_id, num_running_seqs)
                if curr_loras is not None and seq_group.lora_int_id > 0:
                    curr_loras.add(seq_group.lora_int_id)

        # Make sure all queues are updated.
        assert len(running_queue) == 0

        return running_queue, SchedulerRunningOutputs(
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
            preempted=preempted,
            swapped_out=swapped_out,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False))
455

456
457
458
459
460
461
    def _schedule_swapped(
        self,
        swapped_queue: deque,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        policy: Policy,
        enable_chunking: bool = False,
    ) -> Tuple[deque, SchedulerSwappedInOutputs]:
        """Swap back in as many swapped-out sequence groups as fit.

        Groups are visited in priority order. Scheduling stops at the first
        group that cannot be swapped in or does not fit `budget`. A group
        whose LoRA would exceed `lora_config.max_loras` is skipped and put
        back at the front of the returned queue. `budget` and `curr_loras`
        are updated for every group that is scheduled.

        Args:
            swapped_queue: The queue that contains swapped out requests.
                The given arguments are NOT in-place modified.
            budget: The scheduling budget. The argument is in-place updated
                when any requests are swapped in.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are swapped in.
            policy: The sorting policy to sort swapped_queue.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` does not have enough capacity to
                schedule all tokens.

        Returns:
            A tuple of the remaining swapped queue and
            SchedulerSwappedInOutputs.
        """
        # Block operations to perform before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}
        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []

        queue = policy.sort_by_priority(time.time(), swapped_queue)
        # Groups skipped for lack of a LoRA slot. Collected with appendleft
        # and restored with extendleft, which keeps their relative order.
        deferred = deque()

        while queue:
            candidate = queue[0]

            if not self.block_manager.can_swap_in(candidate):
                break

            lora_id = 0
            if self.lora_enabled:
                lora_id = candidate.lora_int_id
                needs_new_lora = lora_id > 0 and lora_id not in curr_loras
                if (needs_new_lora and
                        len(curr_loras) >= self.lora_config.max_loras):
                    # No room for another LoRA; retry this group later.
                    deferred.appendleft(candidate)
                    queue.popleft()
                    continue

            # Keep the RUNNING sequence count within the budget.
            num_new_seqs = candidate.get_max_num_running_seqs()
            num_new_tokens = self._get_num_new_tokens(candidate,
                                                      SequenceStatus.SWAPPED,
                                                      enable_chunking, budget)
            fits = num_new_tokens != 0 and budget.can_schedule(
                num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs)
            if not fits:
                break

            if lora_id > 0 and curr_loras is not None:
                curr_loras.add(lora_id)
            queue.popleft()
            self._swap_in(candidate, blocks_to_swap_in)
            self._append_slots(candidate, blocks_to_copy)

            if candidate.is_prefill():
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(candidate,
                                           token_chunk_size=num_new_tokens))
            else:
                assert num_new_tokens == 1
                decode_seq_groups.append(
                    ScheduledSequenceGroup(candidate, token_chunk_size=1))
            budget.add_num_batched_tokens(candidate.request_id,
                                          num_new_tokens)
            budget.add_num_seqs(candidate.request_id, num_new_seqs)

        queue.extendleft(deferred)

        outputs = SchedulerSwappedInOutputs(
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False))
        return queue, outputs

    def _schedule_prefills(
        self,
        waiting_queue: deque,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> Tuple[deque, SchedulerPrefillOutputs]:
        """Schedule sequence groups that are in prefill stage.

        Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
        as a new prefill (that starts from beginning -> most recently generated
        tokens).

        It schedules waiting requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.
        Prompts longer than `prompt_limit`, or that the block manager can
        never allocate, are marked FINISHED_IGNORED and returned as ignored.

        Args:
            waiting_queue: The queue that contains prefill requests.
                The given arguments are NOT in-place modified.
            budget: The scheduling budget. The argument is in-place updated
                when any requests are scheduled.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are scheduled.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` does not have enough capacity to
                schedule all tokens.

        Returns:
            A tuple of remaining waiting_queue after scheduling and
            SchedulerPrefillOutputs.
        """
        ignored_seq_groups: List[SequenceGroup] = []
        seq_groups: List[ScheduledSequenceGroup] = []
        # We don't sort waiting queue because we assume it is sorted.
        # Copy the queue so that the input queue is not modified.
        waiting_queue = deque(waiting_queue)

        # Groups deferred for lack of a LoRA slot. appendleft + extendleft
        # together preserve their original relative order.
        leftover_waiting_sequences: deque = deque()
        while self._passed_delay(time.time()) and waiting_queue:
            seq_group = waiting_queue[0]

            waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.WAITING,
                                                      enable_chunking, budget)
            if not enable_chunking:
                num_prompt_tokens = waiting_seqs[0].get_len()
                assert num_new_tokens == num_prompt_tokens

            if num_new_tokens > self.prompt_limit:
                logger.warning(
                    "Input prompt (%s tokens) is too long"
                    " and exceeds limit of %s", num_new_tokens,
                    self.prompt_limit)
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            # If the sequence group cannot be allocated, stop.
            can_allocate = self.block_manager.can_allocate(seq_group)
            if can_allocate == AllocStatus.LATER:
                break
            elif can_allocate == AllocStatus.NEVER:
                logger.warning(
                    "Input prompt (%s tokens) is too long"
                    " and exceeds the capacity of block_manager",
                    num_new_tokens)
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                if (lora_int_id > 0 and lora_int_id not in curr_loras
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_waiting_sequences.appendleft(seq_group)
                    waiting_queue.popleft()
                    continue

            num_new_seqs = seq_group.get_max_num_running_seqs()
            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
                break

            # Can schedule this request.
            if curr_loras is not None and lora_int_id > 0:
                curr_loras.add(lora_int_id)
            waiting_queue.popleft()
            self._allocate_and_set_running(seq_group, num_new_tokens)
            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=num_new_tokens))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Queue requests that couldn't be scheduled.
        waiting_queue.extendleft(leftover_waiting_sequences)
        if len(seq_groups) > 0:
            self.prev_prompt = True

        return waiting_queue, SchedulerPrefillOutputs(
            seq_groups=seq_groups,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))

670
671
    def _schedule_default(self) -> SchedulerOutputs:
        """Schedule queued requests with the throughput-oriented policy.

        1. Prefill requests from the waiting queue are batched first, as many
           as the budget allows. This only happens when nothing is currently
           swapped out, because swapped requests have priority.
        2. Decodes of running requests are scheduled only if step 1 scheduled
           no prefill.
        3. Swapped requests are swapped back in only if that decode pass
           neither preempted nor swapped out any group.

        Under GPU memory pressure, running decode requests may be swapped or
        preempted. Chunking is disabled on this path, so neither the running
        nor the swapped queue may yield prefill groups.
        """
        budget = SchedulingBudget(
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
        # Charge already-running sequences to the budget up front so that
        # prefill scheduling cannot push past max_num_seqs.
        for running_group in self.running:
            budget.add_num_seqs(running_group.request_id,
                                running_group.get_max_num_running_seqs())

        if self.lora_enabled:
            curr_loras = {group.lora_int_id for group in self.running}
        else:
            curr_loras = None

        # Defaults used when a scheduling phase below is skipped.
        remaining_waiting = self.waiting
        prefills = SchedulerPrefillOutputs.create_empty()
        remaining_running = self.running
        running_scheduled = SchedulerRunningOutputs.create_empty()
        remaining_swapped = self.swapped
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Swapped requests take priority: new prefills are admitted only
        # when the swapped queue is empty.
        if not self.swapped:
            remaining_waiting, prefills = self._schedule_prefills(
                self.waiting, budget, curr_loras, enable_chunking=False)

        fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")

        # Decodes are skipped whenever prefills were scheduled. Without
        # chunking, self.running holds only decode requests.
        if not prefills.seq_groups:
            remaining_running, running_scheduled = self._schedule_running(
                self.running, budget, curr_loras, fcfs_policy,
                enable_chunking=False)

            # Any preemption or swap-out means there is no free slot, so
            # nothing is swapped back in during this step.
            nothing_evicted = not (running_scheduled.preempted
                                   or running_scheduled.swapped_out)
            if nothing_evicted:
                remaining_swapped, swapped_in = self._schedule_swapped(
                    self.swapped, budget, curr_loras, fcfs_policy)

        assert (budget.num_batched_tokens <=
                self.scheduler_config.max_num_batched_tokens)
        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

        # Preempted groups go back to the front of the waiting queue.
        self.waiting = remaining_waiting
        self.waiting.extendleft(running_scheduled.preempted)

        # Running queue = leftovers, then new prefills, then decodes from the
        # running queue, then decodes that were swapped back in.
        self.running = remaining_running
        for scheduled_groups in (prefills.seq_groups,
                                 running_scheduled.decode_seq_groups,
                                 swapped_in.decode_seq_groups):
            self.running.extend([s.seq_group for s in scheduled_groups])

        self.swapped = remaining_swapped
        self.swapped.extend(running_scheduled.swapped_out)

        # Chunked prefill is off, so no queue may hand back prefill groups.
        assert len(running_scheduled.prefill_seq_groups) == 0
        assert len(swapped_in.prefill_seq_groups) == 0

        # Prefills first, then decodes; schedule() relies on this order.
        scheduled = (prefills.seq_groups +
                     running_scheduled.decode_seq_groups +
                     swapped_in.decode_seq_groups)
        return SchedulerOutputs(
            scheduled_seq_groups=scheduled,
            num_prefill_groups=len(prefills.seq_groups),
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                       swapped_in.blocks_to_copy),
            ignored_seq_groups=prefills.ignored_seq_groups,
            num_lookahead_slots=(prefills.num_lookahead_slots +
                                 running_scheduled.num_lookahead_slots +
                                 swapped_in.num_lookahead_slots),
        )

    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
        """Schedule queued requests with chunked prefill enabled.

        Chunked prefill splits prefill requests into chunks and batches them
        together with decode requests. This policy schedules, in order:

        1. running requests (decodes and unfinished chunked prefills),
        2. swapped requests, unless step 1 preempted or swapped out a group,
        3. new prefill requests from the waiting queue.

        Putting prefills and decodes in the same batch keeps GPU utilization
        high. Decodes also no longer have to wait behind prefills, which
        improves inter-token latency.

        In the returned outputs, ``scheduled_seq_groups`` lists every prefill
        group before every decode group. ``num_prefill_groups`` is the number
        of prefill groups. ``schedule()`` depends on this ordering: it treats
        entry ``i`` as a prompt iff ``i < num_prefill_groups``.
        """
        budget = SchedulingBudget(
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
        curr_loras = set()

        # Swap-in results stay empty if swap-in is skipped below.
        remaining_swapped, swapped_in = (
            self.swapped, SchedulerSwappedInOutputs.create_empty())

        # Running requests (decodes and chunked prefills) are always
        # scheduled first, in FCFS order.
        fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
        remaining_running, running_scheduled = self._schedule_running(
            self.running,
            budget,
            curr_loras,
            fcfs_policy,
            enable_chunking=True)

        # Schedule swapped out requests.
        # If preemption happens, it means we don't have space for swap-in.
        if len(running_scheduled.preempted) + len(
                running_scheduled.swapped_out) == 0:
            remaining_swapped, swapped_in = self._schedule_swapped(
                self.swapped, budget, curr_loras, fcfs_policy)

        # Schedule new prefills with whatever budget remains.
        remaining_waiting, prefills = self._schedule_prefills(
            self.waiting, budget, curr_loras, enable_chunking=True)

        assert (budget.num_batched_tokens <=
                self.scheduler_config.max_num_batched_tokens)
        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

        # Update waiting requests.
        self.waiting = remaining_waiting
        self.waiting.extendleft(running_scheduled.preempted)
        # Update new running requests.
        self.running = remaining_running
        self.running.extend([s.seq_group for s in prefills.seq_groups])
        self.running.extend(
            [s.seq_group for s in running_scheduled.decode_seq_groups])
        self.running.extend(
            [s.seq_group for s in running_scheduled.prefill_seq_groups])
        self.running.extend(
            [s.seq_group for s in swapped_in.decode_seq_groups])
        self.running.extend(
            [s.seq_group for s in swapped_in.prefill_seq_groups])
        # Update swapped requests.
        self.swapped = remaining_swapped
        self.swapped.extend(running_scheduled.swapped_out)

        # Every prefill group must come before every decode group, because
        # schedule() derives is_prompt from the position in this list.
        prefill_groups = (prefills.seq_groups +
                          running_scheduled.prefill_seq_groups +
                          swapped_in.prefill_seq_groups)
        decode_groups = (running_scheduled.decode_seq_groups +
                         swapped_in.decode_seq_groups)
        return SchedulerOutputs(
            scheduled_seq_groups=prefill_groups + decode_groups,
            num_prefill_groups=len(prefill_groups),
            num_batched_tokens=budget.num_batched_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                       swapped_in.blocks_to_copy),
            ignored_seq_groups=prefills.ignored_seq_groups,
            num_lookahead_slots=(prefills.num_lookahead_slots +
                                 running_scheduled.num_lookahead_slots +
                                 swapped_in.num_lookahead_slots),
        )
    def _schedule(self) -> SchedulerOutputs:
        """Dispatch to the scheduling policy selected by the config.

        Uses the chunked-prefill policy when
        ``scheduler_config.chunked_prefill_enabled`` is set, and the default
        throughput-oriented policy otherwise.
        """
        use_chunking = self.scheduler_config.chunked_prefill_enabled
        return (self._schedule_chunked_prefill()
                if use_chunking else self._schedule_default())

    def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
        """Return whether the KV cache has room to keep decoding ``seq_group``.

        Slots are only appended while decoding, so the block manager is asked
        using the non-prefill lookahead slot count.
        """
        lookahead = self._get_num_lookahead_slots(False)
        return self.block_manager.can_append_slots(
            seq_group=seq_group, num_lookahead_slots=lookahead)

    def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
        """Return whether the block manager can swap ``seq_group`` back in.

        Swapping in counts as decoding, so the non-prefill lookahead slot
        count is passed along.
        """
        lookahead = self._get_num_lookahead_slots(False)
        return self.block_manager.can_swap_in(
            seq_group=seq_group, num_lookahead_slots=lookahead)

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Run one scheduling step and build the per-group worker metadata.

        ``_schedule()`` is called first and updates ``self.running``,
        ``self.swapped`` and ``self.waiting``. One ``SequenceGroupMetadata``
        is then built per scheduled group. Finally, the blocks of every
        scheduled group are marked as computed.

        Returns:
            A ``(seq_group_metadata_list, scheduler_outputs)`` tuple. The
            metadata follows the order of
            ``scheduler_outputs.scheduled_seq_groups``.
        """
        scheduler_outputs = self._schedule()
        now = time.time()

        num_prefill_groups = scheduler_outputs.num_prefill_groups
        # `multi_modal_data` is only sent on the first engine -> worker
        # exchange (steps that contain prefills); later steps send None.
        forward_multi_modal = num_prefill_groups > 0

        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        scheduled_groups = scheduler_outputs.scheduled_seq_groups
        for idx, scheduled in enumerate(scheduled_groups):
            group = scheduled.seq_group
            group.maybe_set_first_scheduled_time(now)

            # seq_id -> SequenceData
            seq_data: Dict[int, SequenceData] = {}
            # seq_id -> physical block numbers
            block_tables: Dict[int, List[int]] = {}
            for seq in group.get_seqs(status=SequenceStatus.RUNNING):
                seq_data[seq.seq_id] = seq.data
                block_tables[seq.seq_id] = (
                    self.block_manager.get_block_table(seq))
                self.block_manager.access_all_blocks_in_seq(seq, now)

            computed_block_nums = (
                self.block_manager.get_common_computed_block_ids(
                    group.get_seqs(status=SequenceStatus.RUNNING)))

            # Relies on scheduled_seq_groups listing prefills before decodes.
            is_prompt = idx < num_prefill_groups
            seq_group_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=group.request_id,
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=group.sampling_params,
                    block_tables=block_tables,
                    token_chunk_size=scheduled.token_chunk_size,
                    lora_request=group.lora_request,
                    computed_block_nums=computed_block_nums,
                    state=group.state,
                    multi_modal_data=(group.multi_modal_data
                                      if forward_multi_modal else None),
                ))

        # The engine assumes a model-execution failure crashes the instance
        # and is never retried. So every block in this batch can be treated
        # as computed before the next scheduling call.
        for scheduled in scheduled_groups:
            self.block_manager.mark_blocks_as_computed(scheduled.seq_group)

        return seq_group_metadata_list, scheduler_outputs
    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Delegate forking ``parent_seq`` into ``child_seq`` to the block manager."""
        manager = self.block_manager
        manager.fork(parent_seq, child_seq)
    def free_seq(self, seq: Sequence) -> None:
        """Release the blocks held by ``seq`` through the block manager."""
        manager = self.block_manager
        manager.free(seq)
    def free_finished_seq_groups(self) -> None:
        """Drop finished sequence groups from the running queue.

        ``self.running`` is replaced by a new deque that keeps the unfinished
        groups in their original relative order.
        """
        unfinished = [group for group in self.running
                      if not group.is_finished()]
        self.running = deque(unfinished)
    def _allocate_and_set_running(self, seq_group: SequenceGroup,
                                  num_new_tokens: int) -> None:
        """Allocate blocks for ``seq_group`` and mark its WAITING seqs RUNNING.

        ``num_new_tokens`` is accepted but not used by this method.
        """
        self.block_manager.allocate(seq_group)
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        for waiting_seq in waiting_seqs:
            waiting_seq.status = SequenceStatus.RUNNING

    def _append_slots(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Append new slots to every RUNNING sequence in ``seq_group``.

        Args:
            seq_group (SequenceGroup): The group whose running sequences
                receive new slots.
            blocks_to_copy (Dict[int, List[int]]): Updated in place. For each
                source block id in the mapping returned by the block
                manager's ``append_slots``, the destination block ids are
                appended to the list stored under that source id.
        """
        lookahead = self._get_num_lookahead_slots(is_prefill=False)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            copy_map = self.block_manager.append_slots(seq, lookahead)
            for src_block, dst_blocks in copy_map.items():
                blocks_to_copy.setdefault(src_block, []).extend(dst_blocks)

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> PreemptionMode:
        """Preempt ``seq_group`` by recomputation or by swapping.

        If ``preemption_mode`` is None, the mode is chosen automatically:
        RECOMPUTE when ``seq_group.get_max_num_running_seqs() == 1``, and
        SWAP otherwise. Recomputation is preferred because it costs less than
        swapping. It is not yet supported for groups with several sequences
        (e.g. beam search), so those groups are swapped instead.

        Returns:
            The preemption mode that was applied.

        Raises:
            AssertionError: If the mode is neither RECOMPUTE nor SWAP.
        """
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            single_seq = seq_group.get_max_num_running_seqs() == 1
            preemption_mode = (PreemptionMode.RECOMPUTE
                               if single_seq else PreemptionMode.SWAP)

        if preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        elif preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        else:
            raise AssertionError("Invalid preemption mode.")
        return preemption_mode

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Preempt a single-sequence group so that it is recomputed later.

        The group's only RUNNING sequence is set back to WAITING. Its blocks
        are then freed and its state is reset for recomputation.
        """
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Recomputation is only supported for single-sequence groups.
        assert len(running_seqs) == 1
        for victim in running_seqs:
            victim.status = SequenceStatus.WAITING
            self.free_seq(victim)
            victim.reset_state_for_recompute()

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Preempt ``seq_group`` by swapping its blocks out.

        Thin wrapper around ``_swap_out``. ``blocks_to_swap_out`` is updated
        in place with the resulting block mapping.
        """
        swap_out = self._swap_out
        swap_out(seq_group, blocks_to_swap_out)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Swap ``seq_group`` back in and mark its SWAPPED seqs RUNNING.

        ``blocks_to_swap_in`` is updated in place with the block mapping
        returned by the block manager.
        """
        blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))
        for swapped_seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            swapped_seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap ``seq_group`` out and mark its RUNNING seqs SWAPPED.

        ``blocks_to_swap_out`` is updated in place with the block mapping
        returned by the block manager.

        Raises:
            RuntimeError: If the block manager reports that the group cannot
                be swapped out (lack of CPU swap space).
        """
        manager = self.block_manager
        if not manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        blocks_to_swap_out.update(manager.swap_out(seq_group))
        for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            running_seq.status = SequenceStatus.SWAPPED
    def _passed_delay(self, now: float) -> bool:
        if self.prev_prompt:
            self.last_prompt_latency = now - self.prev_time
        self.prev_time, self.prev_prompt = now, False
        # Delay scheduling prompts to let waiting queue fill up
        if self.scheduler_config.delay_factor > 0 and self.waiting:
            earliest_arrival_time = min(
                [e.metrics.arrival_time for e in self.waiting])
            passed_delay = (
                (now - earliest_arrival_time) >
                (self.scheduler_config.delay_factor * self.last_prompt_latency)
                or not self.running)
        else:
            passed_delay = True
        return passed_delay

    def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
        """The number of slots to allocate per sequence per step, beyond known
        token ids. Speculative decoding uses these slots to store KV activations
        of tokens which may or may not be accepted.

        Speculative decoding does not yet support prefill, so we do not perform
        lookahead allocation for prefill.
        """
        if is_prefill:
            return 0

        return self.scheduler_config.num_lookahead_slots

    def _get_num_new_tokens(self, seq_group: SequenceGroup,
                            status: SequenceStatus, enable_chunking: bool,
                            budget: SchedulingBudget) -> int:
        """Get the number of new tokens to compute for a sequence group.

        Sums ``get_num_new_tokens()`` over the group's sequences that are in
        ``status``. If ``enable_chunking`` is True and exactly one such
        sequence exists, the total is capped at
        ``budget.remaining_token_budget()``. A group with several sequences
        (e.g. running beam search) is in the decoding phase, so it is never
        chunked.

        Returns:
            The number of new tokens to compute, as an ``int``.
        """
        seqs = seq_group.get_seqs(status=status)
        num_new_tokens = sum(seq.get_num_new_tokens() for seq in seqs)
        # More than one sequence means decoding (e.g. beam search); only a
        # single-sequence group may be chunked to fit the remaining budget.
        if enable_chunking and len(seqs) == 1:
            num_new_tokens = min(num_new_tokens,
                                 budget.remaining_token_budget())
        return num_new_tokens