# SPDX-License-Identifier: Apache-2.0

from collections import deque
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
                    Tuple, Union)

from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalKwargs
    from vllm.multimodal.base import PlaceholderRange

logger = init_logger(__name__)


class Scheduler:
    """Token-budget based request scheduler.

    The scheduler keeps a WAITING deque and a RUNNING list of requests. At
    every step, `schedule()` hands out a share of the per-step token budget
    to requests, allocating KV-cache blocks (and encoder-cache entries for
    requests with encoder inputs) as it goes, and returns a
    `SchedulerOutput`. `update_from_output()` then folds the model runner's
    sampled tokens back into the request state.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        """Set up scheduling limits, the KV cache manager and encoder cache.

        Args:
            scheduler_config: Source of `max_num_seqs`,
                `max_num_batched_tokens` and `max_model_len`.
            model_config: Only used here to compute the encoder budget.
            cache_config: Source of `block_size`, `num_gpu_blocks`,
                `sliding_window` and `enable_prefix_caching`.
            lora_config: LoRA configuration, or None when LoRA is not used.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.lora_config = lora_config

        # Scheduling constraints.
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = \
            self.scheduler_config.max_num_batched_tokens
        self.max_model_len = self.scheduler_config.max_model_len

        num_gpu_blocks = cache_config.num_gpu_blocks
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        # Create the KV cache manager.
        self.kv_cache_manager = KVCacheManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            max_model_len=self.max_model_len,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)
        self.block_size = self.cache_config.block_size

        # req_id -> Request
        self.requests: Dict[str, Request] = {}
        # Priority queues for requests.
        self.waiting: Deque[Request] = deque()
        self.running: List[Request] = []

        # The request IDs that are finished in between the previous and the
        # current steps. This is used to notify the workers about the finished
        # requests so that they can free the cached states for those requests.
        # This is flushed at the end of each scheduling step.
        self.finished_req_ids: Set[str] = set()

        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
        # them at each scheduling step.
        # Request id -> CachedRequestData
        self._cached_reqs_data: Dict[str, CachedRequestData] = {}

        # Encoder-related.
        # Calculate encoder cache size if applicable
        # NOTE: For now we use the same budget for both compute and space.
        # This can be changed when we make encoder cache for embedding caching
        # across requests.
        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )

        # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
        # projector if needed). Currently, we assume that the encoder also
        # has the Transformer architecture (e.g., ViT).
        self.max_num_encoder_input_tokens = encoder_compute_budget
        # NOTE: For the models without encoder (e.g., text-only models),
        # the encoder cache will not be initialized because cache size is 0
        # for these models.
        self.encoder_cache_manager = EncoderCacheManager(
            cache_size=encoder_cache_size)

    def schedule(self) -> "SchedulerOutput":
        """Run one scheduling step and return what the workers should execute.

        RUNNING requests are served first. When KV-cache slots cannot be
        allocated for one of them, requests are preempted from the tail of
        the running list (moved to the front of the waiting queue with
        `num_computed_tokens` reset to 0) until the allocation succeeds or
        the request itself has been preempted. WAITING requests are only
        considered when nothing was preempted in this step.

        Returns:
            A `SchedulerOutput` describing the newly scheduled, resumed and
            running requests, their per-request token counts and the
            encoder inputs to run. `self.finished_req_ids` is handed over
            to the output and replaced by a fresh empty set.
        """
        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
        # Each request just has the num_computed_tokens and num_tokens,
        # which is equal to len(prompt_token_ids) + len(output_token_ids).
        # At each step, the scheduler tries to assign tokens to the requests
        # so that each request's num_computed_tokens can catch up its
        # num_tokens. This is general enough to cover chunked prefills,
        # prefix caching, and the "jump decoding" optimization in the future.

        scheduled_new_reqs: List[Request] = []
        scheduled_resumed_reqs: List[Request] = []
        scheduled_running_reqs: List[Request] = []
        preempted_reqs: List[Request] = []

        req_to_new_block_ids: Dict[str, List[int]] = {}
        num_scheduled_tokens: Dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
        # Encoder-related.
        scheduled_encoder_inputs: Dict[str, List[int]] = {}
        encoder_budget = self.max_num_encoder_input_tokens

        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]
            num_new_tokens = request.num_tokens - request.num_computed_tokens
            num_new_tokens = min(num_new_tokens, token_budget)
            assert num_new_tokens > 0

            # Schedule encoder inputs.
            encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
                self._try_schedule_encoder_inputs(request,
                                                  request.num_computed_tokens,
                                                  num_new_tokens,
                                                  encoder_budget))
            if num_new_tokens == 0:
                # The request cannot be scheduled because the encoder budget
                # or the encoder cache is exhausted.
                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                # we do not strictly follow the FCFS scheduling policy and
                # allow the lower-priority requests to be scheduled.
                req_index += 1
                continue

            while True:
                new_blocks = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens)
                if new_blocks is None:
                    # The request cannot be scheduled.
                    # Preempt the lowest-priority request.
                    preempted_req = self.running.pop()
                    self.kv_cache_manager.free(preempted_req)
                    preempted_req.status = RequestStatus.PREEMPTED
                    preempted_req.num_computed_tokens = 0

                    self.waiting.appendleft(preempted_req)
                    preempted_reqs.append(preempted_req)
                    if preempted_req == request:
                        # No more request to preempt.
                        can_schedule = False
                        break
                else:
                    # The request can be scheduled.
                    can_schedule = True
                    break
            if not can_schedule:
                break
            assert new_blocks is not None

            # Schedule the request.
            scheduled_running_reqs.append(request)
            req_to_new_block_ids[request.request_id] = [
                b.block_id for b in new_blocks
            ]
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            req_index += 1

            # Encoder-related.
            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request.request_id] = (
                    encoder_inputs_to_schedule)
                # Allocate the encoder cache.
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_budget = new_encoder_budget

        # Record the LoRAs in scheduled_running_reqs
        requested_loras: Set[int] = set()
        if self.lora_config:
            requested_loras = {
                req.lora_request.lora_int_id
                for req in scheduled_running_reqs
                if req.lora_request and req.lora_request.lora_int_id > 0
            }
            assert len(requested_loras) <= self.lora_config.max_loras

        # Next, schedule the WAITING requests.
        if not preempted_reqs:
            while self.waiting and token_budget > 0:
                if len(self.running) == self.max_num_running_reqs:
                    break

                request = self.waiting[0]

                # Check that adding the request still respects the max_loras
                # constraint.
                if self.lora_config and request.lora_request:
                    req_lora_id = request.lora_request.lora_int_id
                    if len(requested_loras) == self.lora_config.max_loras and (
                            req_lora_id not in requested_loras):
                        # Cannot schedule.
                        # TODO (varun): This means all the other requests in
                        # the WAITING queue will be blocked by this request,
                        # even if,
                        # 1. these other requests do not use LoRA, or,
                        # 2. these other requests use the already requested
                        # LoRAs.
                        # This is too conservative and could be optimized.
                        break

                # Get already-cached tokens.
                computed_blocks, num_computed_tokens = \
                    self.kv_cache_manager.get_computed_blocks(request)
                # Number of tokens to be scheduled.
                # We use `request.num_tokens` instead of
                # `request.num_prompt_tokens` to consider the resumed requests,
                # which have output tokens.
                num_new_tokens = request.num_tokens - num_computed_tokens
                if num_new_tokens == 0:
                    # This happens when prompt length is divisible by the block
                    # size and all blocks are cached. Now we force to recompute
                    # the last block. Note that we have to re-compute an entire
                    # block because allocate_slots() assumes num_computed_tokens
                    # is always a multiple of the block size. This limitation
                    # can potentially be removed in the future to slightly
                    # improve the performance.
                    num_computed_tokens -= self.block_size
                    num_new_tokens = self.block_size
                    computed_blocks.pop()
                num_new_tokens = min(num_new_tokens, token_budget)
                assert num_new_tokens > 0

                # Schedule encoder inputs.
                (encoder_inputs_to_schedule, num_new_tokens,
                 new_encoder_budget) = self._try_schedule_encoder_inputs(
                     request, num_computed_tokens, num_new_tokens,
                     encoder_budget)
                if num_new_tokens == 0:
                    # The request cannot be scheduled.
                    break

                new_blocks = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens, computed_blocks)
                if new_blocks is None:
                    # The request cannot be scheduled.
                    break

                self.waiting.popleft()
                self.running.append(request)
                if request.status == RequestStatus.WAITING:
                    scheduled_new_reqs.append(request)
                elif request.status == RequestStatus.PREEMPTED:
                    scheduled_resumed_reqs.append(request)
                else:
                    raise RuntimeError(
                        f"Invalid request status: {request.status}")

                if self.lora_config and request.lora_request:
                    requested_loras.add(request.lora_request.lora_int_id)
                req_to_new_block_ids[request.request_id] = [
                    b.block_id for b in computed_blocks + new_blocks
                ]
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                request.status = RequestStatus.RUNNING
                request.num_computed_tokens = num_computed_tokens

                # Encoder-related.
                if encoder_inputs_to_schedule:
                    scheduled_encoder_inputs[request.request_id] = (
                        encoder_inputs_to_schedule)
                    # Allocate the encoder cache.
                    for i in encoder_inputs_to_schedule:
                        self.encoder_cache_manager.allocate(request, i)
                    encoder_budget = new_encoder_budget

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        # Since some requests in the RUNNING queue may not be scheduled in
        # this step, the total number of scheduled requests can be smaller than
        # len(self.running).
        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
                len(scheduled_running_reqs) <= len(self.running))

        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        num_common_prefix_blocks = 0
        if self.running:
            any_request = self.running[0]
            num_common_prefix_blocks = (
                self.kv_cache_manager.get_num_common_prefix_blocks(
                    any_request, len(self.running)))

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
                                        req_to_new_block_ids[req.request_id],
                                        req.num_computed_tokens)
            for req in scheduled_new_reqs
        ]
        resumed_reqs_data = [
            self._make_cached_request_data(
                req,
                req_to_new_block_ids[req.request_id],
                req.num_computed_tokens,
                resumed_from_preemption=True,
            ) for req in scheduled_resumed_reqs
        ]
        running_reqs_data = [
            self._make_cached_request_data(
                req,
                req_to_new_block_ids[req.request_id],
                req.num_computed_tokens,
                resumed_from_preemption=False,
            ) for req in scheduled_running_reqs
        ]
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
        )

        # The output now owns the old set; start collecting for the next step.
        self.finished_req_ids = set()
        return scheduler_output

    def _make_cached_request_data(
        self,
        request: Request,
        new_block_ids: List[int],
        num_computed_tokens: int,
        resumed_from_preemption: bool,
    ) -> "CachedRequestData":
        """Return the `CachedRequestData` for `request`, reusing a cached one.

        If an object already exists for the request ID, it is updated in
        place and returned; otherwise a new one is created and stored in
        `self._cached_reqs_data`.
        """
        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
        # them at each scheduling step.
        if request.request_id in self._cached_reqs_data:
            req_data = self._cached_reqs_data[request.request_id]
            req_data.resumed_from_preemption = resumed_from_preemption
            req_data.new_block_ids = new_block_ids
            req_data.num_computed_tokens = num_computed_tokens
        else:
            req_data = CachedRequestData.from_request(request,
                                                      resumed_from_preemption,
                                                      new_block_ids,
                                                      num_computed_tokens)
            self._cached_reqs_data[request.request_id] = req_data
        return req_data

    def _try_schedule_encoder_inputs(
        self,
        request: Request,
        num_computed_tokens: int,
        num_new_tokens: int,
        encoder_budget: int,
    ) -> Tuple[List[int], int, int]:
        """
        Determine which encoder inputs need to be scheduled in the current step,
        and update `num_new_tokens` and encoder token budget accordingly.

        An encoder input will be scheduled if:
        - Its output tokens overlap with the range of tokens being computed
        in this step, i.e.,
        [num_computed_tokens, num_computed_tokens + num_new_tokens).
        - It is not already computed and stored in the encoder cache.
        - There is sufficient encoder token budget to process it.
        - The encoder cache has space to store it.

        If an encoder input cannot be scheduled due to cache or budget
        limitations, the method adjusts `num_new_tokens` to schedule only the
        decoder tokens up to just before the unschedulable encoder input.

        Returns:
            A tuple `(encoder_inputs_to_schedule, num_new_tokens,
            encoder_budget)`: the indices into `request.mm_positions` to run
            in this step, the possibly reduced number of new tokens, and the
            encoder budget left after those inputs. For a request without
            encoder inputs the arguments are returned unchanged with an
            empty list.
        """
        if not request.has_encoder_inputs():
            return [], num_new_tokens, encoder_budget

        encoder_inputs_to_schedule: List[int] = []
        mm_positions = request.mm_positions
        assert mm_positions is not None
        assert len(mm_positions) > 0
        for i, pos_info in enumerate(mm_positions):
            start_pos = pos_info["offset"]
            num_encoder_tokens = pos_info["length"]

            # The encoder output is needed if the two ranges overlap:
            # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
            # [start_pos, start_pos + num_encoder_tokens)
            if start_pos >= num_computed_tokens + num_new_tokens:
                # The encoder input is not needed in this step.
                break
            if start_pos + num_encoder_tokens <= num_computed_tokens:
                # The encoder input is already computed and stored
                # in the decoder's KV cache.
                continue

            if self.encoder_cache_manager.has_cache(request, i):
                # The encoder input is already computed and cached.
                continue
            if (not self.encoder_cache_manager.can_allocate(request, i)
                    or num_encoder_tokens > encoder_budget):
                # The encoder cache is full or the encoder budget is exhausted.
                # NOTE(woosuk): We assume that the encoder input tokens should
                # be processed altogether, as the encoder usually uses
                # bidirectional attention.
                if num_computed_tokens < start_pos:
                    # We only schedule the decoder tokens just before the
                    # encoder input.
                    num_new_tokens = start_pos - num_computed_tokens
                else:
                    # Because of prefix caching, num_computed_tokens is greater
                    # than start_pos even though its encoder input is not
                    # available. In this case, we can't schedule any token for
                    # the request in this step.
                    num_new_tokens = 0
                break

            encoder_budget -= num_encoder_tokens
            encoder_inputs_to_schedule.append(i)
        return encoder_inputs_to_schedule, num_new_tokens, encoder_budget

    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> EngineCoreOutputs:
        """Apply one step's model output to the running requests.

        For each running request that was scheduled in the step, this
        advances `num_computed_tokens`, frees encoder-cache entries whose
        tokens are now fully computed, and, once the request has caught up
        with `num_tokens`, appends the sampled token and checks the stop
        conditions. Stopped requests are freed and dropped from
        `self.running`.

        Returns:
            `EngineCoreOutputs` with one `EngineCoreOutput` per request that
            produced a token in this step, plus the current scheduler stats.
        """
        # NOTE(woosuk): This method doesn't consider speculative decoding.
        sampled_token_ids = model_runner_output.sampled_token_ids
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
        new_running: List[Request] = []
        outputs: List[EngineCoreOutput] = []

        # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
        # loop can be a performance bottleneck. We should do our best to avoid
        # expensive operations inside the loop.
        for request in self.running:
            req_id = request.request_id
            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
            if num_tokens_scheduled == 0:
                # The request was not scheduled in this step.
                new_running.append(request)
                continue

            request.num_computed_tokens += num_tokens_scheduled
            # When the request's num_computed_tokens catches up its num_tokens,
            # the request generates output tokens. Otherwise, we ignore the
            # sampler output for the request.
            assert request.num_computed_tokens <= request.num_tokens

            cached_encoder_input_ids = (
                self.encoder_cache_manager.get_cached_input_ids(request))
            # OPTIMIZATION: Avoid list(set) if the set is empty.
            if cached_encoder_input_ids:
                # Iterate over a copy: entries are freed inside the loop.
                for input_id in list(cached_encoder_input_ids):
                    start_pos = request.mm_positions[input_id]["offset"]
                    num_tokens = request.mm_positions[input_id]["length"]
                    if start_pos + num_tokens <= request.num_computed_tokens:
                        # The encoder output is already processed and stored
                        # in the decoder's KV cache.
                        self.encoder_cache_manager.free_encoder_input(
                            request, input_id)

            if request.num_computed_tokens == request.num_tokens:
                req_index = model_runner_output.req_id_to_index[req_id]
                # NOTE(woosuk): Currently, we assume that each request
                # generates at most one token at each step.
                token_id = sampled_token_ids[req_index]
                request.append_output_token_ids(token_id)
                num_new_tokens = 1
                # TODO: Update the KV cache manager for prefix caching.

                # Check for stop and update request state.
                # This must be called before we make the EngineCoreOutput.
                stopped = self._check_stop(request)
                if stopped:
                    self._free_request(request)

                # Add EngineCoreOutput for this Request.
                output = EngineCoreOutput(
                    request_id=req_id,
                    new_token_ids=request.output_token_ids[-num_new_tokens:],
                    finished=request.is_finished(),
                    finish_reason=request.get_finished_reason(),
                    stop_reason=request.stop_reason)
                outputs.append(output)

                # A stopped request has already been freed above; skip the
                # append below so it is dropped from the running list.
                if stopped:
                    continue

            new_running.append(request)
        self.running = new_running
        return EngineCoreOutputs(
            outputs=outputs,
            scheduler_stats=self.make_stats(),
        )

    def _check_stop(self, request: Request) -> bool:
        """Update `request.status` if a stop condition is met.

        Checks, in order: the length caps (`max_model_len` and the
        request's `max_tokens`), the EOS token (unless
        `sampling_params.ignore_eos` is set), and
        `sampling_params.stop_token_ids`. For a stop-token match,
        `request.stop_reason` is set to the matching token ID.

        Returns:
            True if the request has finished, False otherwise.
        """
        if (request.num_tokens >= self.max_model_len
                or request.num_output_tokens >= request.max_tokens):
            request.status = RequestStatus.FINISHED_LENGTH_CAPPED
            return True

        sampling_params = request.sampling_params
        last_token_id = request.output_token_ids[-1]
        if (not sampling_params.ignore_eos
                and last_token_id == request.eos_token_id):
            request.status = RequestStatus.FINISHED_STOPPED
            return True

        if last_token_id in (sampling_params.stop_token_ids or ()):
            request.status = RequestStatus.FINISHED_STOPPED
            request.stop_reason = last_token_id
            return True
        return False

    def add_request(self, request: Request) -> None:
        """Append `request` to the waiting queue and register it by ID."""
        self.waiting.append(request)
        self.requests[request.request_id] = request

    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: RequestStatus,
    ) -> None:
        """Handles the finish signal from outside the scheduler.

        For example, the API server can abort a request when the client
        disconnects.

        Args:
            request_ids: A single request ID or an iterable of request IDs.
                IDs unknown to the scheduler are ignored.
            finished_status: The finished status to assign to each request.
        """
        assert RequestStatus.is_finished(finished_status)
        if isinstance(request_ids, str):
            request_ids = (request_ids, )
        request_ids = set(request_ids)

        for req_id in request_ids:
            request = self.requests.get(req_id)
            if request is None:
                # Invalid request ID.
                continue

            # Any request that is not RUNNING is expected to be in the
            # waiting queue.
            if request.status == RequestStatus.RUNNING:
                self.running.remove(request)
            else:
                self.waiting.remove(request)
            request.status = finished_status
            self._free_request(request)

    def _free_request(self, request: Request) -> None:
        """Release everything held for a finished request.

        Frees its KV-cache and encoder-cache entries, drops its cached
        `CachedRequestData`, removes it from `self.requests`, and records
        its ID in `self.finished_req_ids`.
        """
        assert request.is_finished()
        self.kv_cache_manager.free(request)
        self.encoder_cache_manager.free(request)
        self._cached_reqs_data.pop(request.request_id, None)
        del self.requests[request.request_id]
        self.finished_req_ids.add(request.request_id)

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests that are waiting or running."""
        return len(self.waiting) + len(self.running)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is waiting or running."""
        return self.get_num_unfinished_requests() > 0

    def reset_prefix_cache(self) -> bool:
        """Delegate to the KV cache manager's `reset_prefix_cache()`."""
        return self.kv_cache_manager.reset_prefix_cache()

    def make_stats(self) -> SchedulerStats:
        """Snapshot the queue lengths and the KV cache manager's usage."""
        return SchedulerStats(
            num_running_reqs=len(self.running),
            num_waiting_reqs=len(self.waiting),
            gpu_cache_usage=self.kv_cache_manager.usage,
        )

@dataclass
class NewRequestData:
    """Data describing a request that is scheduled for the first time.

    Built by `from_request()`, which copies the prompt, multimodal,
    sampling and LoRA fields from the `Request` and takes the block IDs
    and computed-token count from the caller.
    """

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    # Multimodal fields, copied as-is from the request.
    mm_inputs: List["MultiModalKwargs"]
    mm_hashes: List[str]
    mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
    # Block IDs and number of computed tokens supplied by the caller.
    block_ids: List[int]
    num_computed_tokens: int
    lora_request: Optional[LoRARequest]

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "NewRequestData":
        """Build a `NewRequestData` from `request` and its scheduling state.

        Args:
            request: The request whose fields are copied.
            block_ids: Block IDs to store in `block_ids`.
            num_computed_tokens: Value to store in `num_computed_tokens`.
        """
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            prompt=request.prompt,
            mm_inputs=request.mm_inputs,
            mm_hashes=request.mm_hashes,
            mm_positions=request.mm_positions,
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
            lora_request=request.lora_request,
        )


@dataclass
class CachedRequestData:
    """Per-step update for a request that has been scheduled before.

    Carries only the request ID, the newly assigned block IDs, the
    computed-token count and whether the request was resumed from
    preemption.
    """

    req_id: str
    # If resumed_from_preemption is False, new_block_ids will be appended to
    # the request's block IDs. If True, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: bool
    new_block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        resumed_from_preemption: bool,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "CachedRequestData":
        """Build a `CachedRequestData` keyed by `request.request_id`.

        Only `request.request_id` is read from the request; the remaining
        fields are stored exactly as passed in.
        """
        return cls(
            req_id=request.request_id,
            resumed_from_preemption=resumed_from_preemption,
            new_block_ids=new_block_ids,
            num_computed_tokens=num_computed_tokens,
        )


@dataclass
class SchedulerOutput:
    """Result of one `Scheduler.schedule()` step."""

    # Requests scheduled for the first time.
    scheduled_new_reqs: List[NewRequestData]
    # Requests scheduled before: resumed-from-preemption ones first, then
    # the ones that were already running.
    scheduled_cached_reqs: List[CachedRequestData]

    # req_id -> number of tokens scheduled for the request in this step.
    num_scheduled_tokens: Dict[str, int]
    # Sum of the values of num_scheduled_tokens.
    total_num_scheduled_tokens: int
    # req_id -> indices (into the request's mm_positions) of the encoder
    # inputs to process in this step.
    scheduled_encoder_inputs: Dict[str, List[int]]
    # Number of common prefix blocks reported by the KV cache manager for
    # the running requests.
    num_common_prefix_blocks: int

    # IDs of the requests that finished between the previous and the
    # current steps.
    finished_req_ids: Set[str]
    # Freed encoder-cache entries as reported by the encoder cache manager
    # (presumably (req_id, input_id) pairs -- confirm against
    # EncoderCacheManager.get_freed_ids()).
    free_encoder_input_ids: List[Tuple[str, int]]