# SPDX-License-Identifier: Apache-2.0

import time
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
                         SpeculativeConfig)
from vllm.logger import init_logger
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                           SchedulerOutput)
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                            EngineCoreOutput, EngineCoreOutputs)
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)


class Scheduler:

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        log_stats: bool,
    ) -> None:
        """Set up scheduling limits, the KV/encoder cache managers and queues.

        Args:
            scheduler_config: Supplies ``max_num_seqs``,
                ``max_num_batched_tokens`` and ``max_model_len``.
            model_config: Used together with ``scheduler_config`` to compute
                the encoder compute budget and encoder cache size.
            cache_config: KV cache settings. ``num_gpu_blocks`` must already
                be a positive int when the scheduler is constructed.
            lora_config: LoRA settings, or None when LoRA is disabled.
            speculative_config: Speculative decoding settings, or None.
            log_stats: When True, per-request events and scheduler stats
                are recorded.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.speculative_config = speculative_config
        self.log_stats = log_stats

        # Scheduling constraints.
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = \
            self.scheduler_config.max_num_batched_tokens
        self.max_model_len = self.scheduler_config.max_model_len

        # The number of KV cache blocks must be resolved (profiled/configured)
        # before the scheduler can be built.
        num_gpu_blocks = cache_config.num_gpu_blocks
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        # Create the KV cache manager.
        self.kv_cache_manager = KVCacheManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            max_model_len=self.max_model_len,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching,
            log_stats=self.log_stats)
        self.block_size = self.cache_config.block_size

        # req_id -> Request
        self.requests: Dict[str, Request] = {}
        # Priority queues for requests.
        self.waiting: Deque[Request] = deque()
        self.running: List[Request] = []
        # The requests that have been scheduled and are being executed
        # by the executor.
        self.scheduled_req_ids: Set[str] = set()

        # The request IDs that are finished in between the previous and the
        # current steps. This is used to notify the workers about the finished
        # requests so that they can free the cached states for those requests.
        # This is flushed at the end of each scheduling step.
        self.finished_req_ids: Set[str] = set()

        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
        # them at each scheduling step.
        # Request id -> CachedRequestData
        self._cached_reqs_data: Dict[str, CachedRequestData] = {}

        # Encoder-related.
        # Calculate encoder cache size if applicable
        # NOTE: For now we use the same budget for both compute and space.
        # This can be changed when we make encoder cache for embedding caching
        # across requests.
        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )

        # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
        # projector if needed). Currently, we assume that the encoder also
        # has the Transformer architecture (e.g., ViT).
        self.max_num_encoder_input_tokens = encoder_compute_budget
        # NOTE: For the models without encoder (e.g., text-only models),
        # the encoder cache will not be initialized because cache size is 0
        # for these models.
        self.encoder_cache_manager = EncoderCacheManager(
            cache_size=encoder_cache_size)

    def schedule(self) -> "SchedulerOutput":
        """Decide which requests run in the next step and with how many tokens.

        RUNNING requests are considered first, in queue order. Requests that
        are already in flight in the executor are skipped. A request blocked
        by the encoder budget or cache is skipped, and the loop moves on to
        the next one. If the KV cache cannot hold a running request's new
        tokens, requests are preempted from the back of ``self.running``
        until the allocation succeeds or the request itself is preempted.
        Each preempted request is freed, reset to zero computed tokens and
        pushed to the front of ``self.waiting``. WAITING requests are
        admitted only if nothing was preempted in this step.

        Returns:
            A SchedulerOutput describing:
            - the newly admitted requests;
            - the resumed and still-running requests;
            - per-request token counts;
            - the scheduled speculative tokens and encoder inputs;
            - the common prefix block count;
            - the IDs of requests finished since the previous call. This set
              is reset after the output is built.
        """
        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
        # Each request just has the num_computed_tokens and
        # num_tokens_with_spec. num_tokens_with_spec =
        # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
        # At each step, the scheduler tries to assign tokens to the requests
        # so that each request's num_computed_tokens can catch up its
        # num_tokens_with_spec. This is general enough to cover
        # chunked prefills, prefix caching, speculative decoding,
        # and the "jump decoding" optimization in the future.

        scheduled_new_reqs: List[Request] = []
        scheduled_resumed_reqs: List[Request] = []
        scheduled_running_reqs: List[Request] = []
        preempted_reqs: List[Request] = []

        req_to_new_block_ids: Dict[str, List[int]] = {}
        num_scheduled_tokens: Dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
        # Encoder-related.
        scheduled_encoder_inputs: Dict[str, List[int]] = {}
        encoder_budget = self.max_num_encoder_input_tokens
        # Spec decode-related.
        scheduled_spec_decode_tokens: Dict[str, List[int]] = {}

        # For logging.
        scheduled_timestamp = time.monotonic()

        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]
            if request.request_id in self.scheduled_req_ids:
                # This request has already been scheduled.
                req_index += 1
                continue

            num_new_tokens = (request.num_tokens_with_spec -
                              request.num_computed_tokens)
            num_new_tokens = min(num_new_tokens, token_budget)
            assert num_new_tokens > 0

            # Schedule encoder inputs.
            encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
                self._try_schedule_encoder_inputs(request,
                                                  request.num_computed_tokens,
                                                  num_new_tokens,
                                                  encoder_budget))
            if num_new_tokens == 0:
                # The request cannot be scheduled because the encoder budget
                # or the encoder cache is exhausted.
                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                # we do not strictly follow the FCFS scheduling policy and
                # allow the lower-priority requests to be scheduled.
                req_index += 1
                continue

            # Keep preempting from the tail of the running queue until the
            # KV cache can fit this request's new tokens, or until the
            # request itself is the one that gets preempted.
            while True:
                new_blocks = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens)
                if new_blocks is None:
                    # The request cannot be scheduled.
                    # Preempt the lowest-priority request.
                    preempted_req = self.running.pop()
                    self.kv_cache_manager.free(preempted_req)
                    preempted_req.status = RequestStatus.PREEMPTED
                    preempted_req.num_computed_tokens = 0
                    self.request_preempted(preempted_req, scheduled_timestamp)

                    self.waiting.appendleft(preempted_req)
                    preempted_reqs.append(preempted_req)
                    if preempted_req == request:
                        # No more request to preempt.
                        can_schedule = False
                        break
                else:
                    # The request can be scheduled.
                    can_schedule = True
                    break
            if not can_schedule:
                break
            assert new_blocks is not None

            # Schedule the request.
            scheduled_running_reqs.append(request)
            self.scheduled_req_ids.add(request.request_id)
            req_to_new_block_ids[request.request_id] = [
                b.block_id for b in new_blocks
            ]
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            req_index += 1

            # Speculative decode related.
            if request.spec_token_ids:
                # Tokens scheduled beyond request.num_tokens are draft tokens.
                num_scheduled_spec_tokens = (num_new_tokens +
                                             request.num_computed_tokens -
                                             request.num_tokens)
                if num_scheduled_spec_tokens > 0:
                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                    scheduled_spec_decode_tokens[request.request_id] = (
                        request.spec_token_ids)

            # Encoder-related.
            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request.request_id] = (
                    encoder_inputs_to_schedule)
                # Allocate the encoder cache.
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_budget = new_encoder_budget

        # Record the LoRAs in scheduled_running_reqs
        requested_loras: Set[int] = set()
        if self.lora_config:
            requested_loras = set(
                req.lora_request.lora_int_id for req in scheduled_running_reqs
                if req.lora_request and req.lora_request.lora_int_id > 0)
            assert len(requested_loras) <= self.lora_config.max_loras

        # Next, schedule the WAITING requests.
        if not preempted_reqs:
            while self.waiting and token_budget > 0:
                if len(self.running) == self.max_num_running_reqs:
                    break

                request = self.waiting[0]

                # Check that adding the request still respects the max_loras
                # constraint.
                if self.lora_config and request.lora_request:
                    req_lora_id = request.lora_request.lora_int_id
                    if len(requested_loras) == self.lora_config.max_loras and (
                            req_lora_id not in requested_loras):
                        # Cannot schedule.
                        # TODO (varun): This means all the other requests in
                        # the WAITING queue will be blocked by this request,
                        # even if,
                        # 1. these other requests do not use LoRA, or,
                        # 2. these other requests use the already requested
                        # LoRAs.
                        # This is too conservative and could be optimized.
                        break

                # Get already-cached tokens.
                computed_blocks, num_computed_tokens = \
                    self.kv_cache_manager.get_computed_blocks(request)
                # Number of tokens to be scheduled.
                # We use `request.num_tokens` instead of
                # `request.num_prompt_tokens` to consider the resumed requests,
                # which have output tokens.
                num_new_tokens = request.num_tokens - num_computed_tokens
                if num_new_tokens == 0:
                    # This happens when prompt length is divisible by the block
                    # size and all blocks are cached. Now we force to recompute
                    # the last block. Note that we have to re-compute an entire
                    # block because allocate_slots() assumes num_computed_tokens
                    # is always a multiple of the block size. This limitation
                    # can potentially be removed in the future to slightly
                    # improve the performance.
                    num_computed_tokens -= self.block_size
                    num_new_tokens = self.block_size
                    computed_blocks.pop()
                num_new_tokens = min(num_new_tokens, token_budget)
                assert num_new_tokens > 0

                # Schedule encoder inputs.
                (encoder_inputs_to_schedule, num_new_tokens,
                 new_encoder_budget) = self._try_schedule_encoder_inputs(
                     request, num_computed_tokens, num_new_tokens,
                     encoder_budget)
                if num_new_tokens == 0:
                    # The request cannot be scheduled.
                    break

                new_blocks = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens, computed_blocks)
                if new_blocks is None:
                    # The request cannot be scheduled.
                    break

                self.waiting.popleft()
                self.running.append(request)
                self.scheduled_req_ids.add(request.request_id)
                self.request_scheduled(request, scheduled_timestamp)
                # The status still reflects where the request came from:
                # WAITING means brand new, PREEMPTED means resumed.
                if request.status == RequestStatus.WAITING:
                    scheduled_new_reqs.append(request)
                elif request.status == RequestStatus.PREEMPTED:
                    scheduled_resumed_reqs.append(request)
                else:
                    raise RuntimeError(
                        f"Invalid request status: {request.status}")

                if self.lora_config and request.lora_request:
                    requested_loras.add(request.lora_request.lora_int_id)
                req_to_new_block_ids[request.request_id] = [
                    b.block_id for b in computed_blocks + new_blocks
                ]
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                request.status = RequestStatus.RUNNING
                request.num_computed_tokens = num_computed_tokens

                # Encoder-related.
                if encoder_inputs_to_schedule:
                    scheduled_encoder_inputs[request.request_id] = (
                        encoder_inputs_to_schedule)
                    # Allocate the encoder cache.
                    for i in encoder_inputs_to_schedule:
                        self.encoder_cache_manager.allocate(request, i)
                    encoder_budget = new_encoder_budget

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        # Since some requests in the RUNNING queue may not be scheduled in
        # this step, the total number of scheduled requests can be smaller than
        # len(self.running).
        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
                len(scheduled_running_reqs) <= len(self.running))

        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        num_common_prefix_blocks = 0
        if self.running:
            any_request = self.running[0]
            num_common_prefix_blocks = (
                self.kv_cache_manager.get_num_common_prefix_blocks(
                    any_request, len(self.running)))

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
                                        req_to_new_block_ids[req.request_id])
            for req in scheduled_new_reqs
        ]
        resumed_reqs_data = [
            self._make_cached_request_data(
                req,
                num_scheduled_tokens[req.request_id],
                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                req_to_new_block_ids[req.request_id],
                resumed_from_preemption=True,
            ) for req in scheduled_resumed_reqs
        ]
        running_reqs_data = [
            self._make_cached_request_data(
                req,
                num_scheduled_tokens[req.request_id],
                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                req_to_new_block_ids[req.request_id],
                resumed_from_preemption=False,
            ) for req in scheduled_running_reqs
        ]
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
        )

        self.finished_req_ids = set()
        return scheduler_output

    def _make_cached_request_data(
377
378
        self,
        request: Request,
379
380
        num_scheduled_tokens: int,
        num_scheduled_spec_tokens: int,
381
        new_block_ids: List[int],
382
383
384
        resumed_from_preemption: bool,
    ) -> "CachedRequestData":
        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
385
        # them at each scheduling step.
386
387
388
389
390
391
        num_computed_tokens = request.num_computed_tokens
        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
        new_token_ids = request.all_token_ids[
            num_computed_tokens:num_computed_tokens + num_regular_tokens]
        req_data = self._cached_reqs_data.get(request.request_id)
        if req_data is not None:
392
            req_data.resumed_from_preemption = resumed_from_preemption
393
            req_data.new_token_ids = new_token_ids
394
395
396
            req_data.new_block_ids = new_block_ids
            req_data.num_computed_tokens = num_computed_tokens
        else:
397
398
            req_data = CachedRequestData.from_request(request,
                                                      resumed_from_preemption,
399
400
                                                      new_token_ids,
                                                      new_block_ids)
401
            self._cached_reqs_data[request.request_id] = req_data
402
403
        return req_data

404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
    def _try_schedule_encoder_inputs(
        self,
        request: Request,
        num_computed_tokens: int,
        num_new_tokens: int,
        encoder_budget: int,
    ) -> Tuple[List[int], int, int]:
        """
        Determine which encoder inputs need to be scheduled in the current step,
        and update `num_new_tokens` and encoder token budget accordingly.

        An encoder input will be scheduled if:
        - Its output tokens overlap with the range of tokens being computed
        in this step, i.e.,
        [num_computed_tokens, num_computed_tokens + num_new_tokens).
        - It is not already computed and stored in the encoder cache.
        - There is sufficient encoder token budget to process it.
        - The encoder cache has space to store it.

        If an encoder input cannot be scheduled due to cache or budget
        limitations, the method adjusts `num_new_tokens` to schedule only the
        decoder tokens up to just before the unschedulable encoder input.
        """
        if not request.has_encoder_inputs():
            return [], num_new_tokens, encoder_budget

        encoder_inputs_to_schedule: List[int] = []
        mm_positions = request.mm_positions
        assert mm_positions is not None
        assert len(mm_positions) > 0
        for i, pos_info in enumerate(mm_positions):
            start_pos = pos_info["offset"]
            num_encoder_tokens = pos_info["length"]

            # The encoder output is needed if the two ranges overlap:
            # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
            # [start_pos, start_pos + num_encoder_tokens)
            if start_pos >= num_computed_tokens + num_new_tokens:
                # The encoder input is not needed in this step.
                break
            if start_pos + num_encoder_tokens <= num_computed_tokens:
                # The encoder input is already computed and stored
                # in the decoder's KV cache.
                continue

            if self.encoder_cache_manager.has_cache(request, i):
                # The encoder input is already computed and cached.
                continue
452
453
454
455
456
            if (not self.encoder_cache_manager.can_allocate(request, i)
                    or num_encoder_tokens > encoder_budget):
                # The encoder cache is full or the encoder budget is exhausted.
                # NOTE(woosuk): We assume that the encoder input tokens should
                # be processed altogether, as the encoder usually uses
457
                # bidirectional attention.
458
459
460
461
462
463
464
465
466
467
                if num_computed_tokens < start_pos:
                    # We only schedule the decoder tokens just before the
                    # encoder input.
                    num_new_tokens = start_pos - num_computed_tokens
                else:
                    # Because of prefix caching, num_computed_tokens is greater
                    # than start_pos even though its encoder input is not
                    # available. In this case, we can't schedule any token for
                    # the request in this step.
                    num_new_tokens = 0
468
469
470
471
472
473
                break

            encoder_budget -= num_encoder_tokens
            encoder_inputs_to_schedule.append(i)
        return encoder_inputs_to_schedule, num_new_tokens, encoder_budget

474
475
476
477
    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> EngineCoreOutputs:
        """Apply one step's model outputs to the requests that were scheduled.

        Requests with no tokens scheduled in ``scheduler_output`` are kept
        unchanged. For every other running request, this method:
        - advances ``num_computed_tokens``, net of rejected speculative
          tokens;
        - frees encoder cache entries that are fully covered by the
          computed tokens;
        - stores newly proposed speculative tokens;
        - once the request has caught up with its known tokens, appends the
          sampled tokens and checks stop conditions, freeing finished
          requests;
        - emits an EngineCoreOutput when there are new tokens or prompt
          logprobs to report.

        Returns:
            EngineCoreOutputs with the per-request outputs and the result of
            ``make_stats()``, which is None when stats logging is disabled.
        """
        sampled_token_ids = model_runner_output.sampled_token_ids
        spec_token_ids = model_runner_output.spec_token_ids
        logprobs = model_runner_output.logprobs
        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens

        new_running: List[Request] = []
        outputs: List[EngineCoreOutput] = []

        # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
        # loop can be a performance bottleneck. We should do our best to avoid
        # expensive operations inside the loop.
        for request in self.running:
            req_id = request.request_id
            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
            if num_tokens_scheduled == 0:
                # The request was not scheduled in this step.
                new_running.append(request)
                continue

            req_index = model_runner_output.req_id_to_index[req_id]
            generated_token_ids = sampled_token_ids[req_index]
            if req_id not in scheduler_output.scheduled_spec_decode_tokens:
                # When the request's num_computed_tokens catches up
                # its num_tokens, the request generates output tokens.
                # Otherwise, we ignore the sampler output for the request.
                request.num_computed_tokens += num_tokens_scheduled
                assert request.num_computed_tokens <= request.num_tokens
            else:
                # num_computed_tokens_step represents the number of tokens
                # processed in the current step, considering scheduled
                # tokens and rejections.
                # It is calculated as:
                # num_computed_tokens_step = num_scheduled_tokens -
                #                            num_tokens_rejected,
                # where num_tokens_rejected is given by:
                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
                scheduled_spec_token_ids = (
                    scheduler_output.scheduled_spec_decode_tokens[req_id])

                num_computed_tokens_step = num_scheduled_tokens[req_id] - (
                    len(scheduled_spec_token_ids) + 1 -
                    len(generated_token_ids))
                request.num_computed_tokens += num_computed_tokens_step

            cached_encoder_input_ids = (
                self.encoder_cache_manager.get_cached_input_ids(request))
            # OPTIMIZATION: Avoid list(set) if the set is empty.
            if cached_encoder_input_ids:
                # Iterate over a copy because entries are freed in the loop.
                for input_id in list(cached_encoder_input_ids):
                    start_pos = request.mm_positions[input_id]["offset"]
                    num_tokens = request.mm_positions[input_id]["length"]
                    if start_pos + num_tokens <= request.num_computed_tokens:
                        # The encoder output is already processed and stored
                        # in the decoder's KV cache.
                        self.encoder_cache_manager.free_encoder_input(
                            request, input_id)

            # Add newly generated spec token ids to the request.
            if spec_token_ids is not None:
                request.spec_token_ids = spec_token_ids[req_index]

            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)

            stopped = False
            new_logprobs = None
            new_token_ids: List[int] = []

            if request.num_computed_tokens >= request.num_tokens:
                for output_token_id in generated_token_ids:
                    request.append_output_token_ids(output_token_id)
                    new_token_ids.append(output_token_id)

                    # Check for stop and update request state.
                    # This must be called before we make the EngineCoreOutput.
                    stopped = self._check_stop(request)
                    if stopped:
                        self._free_request(request)
                        break

                # Extract sample logprobs if needed.
                if request.sampling_params.logprobs is not None:
                    assert logprobs is not None
                    # NOTE: once we support N tokens per step (spec decode),
                    # the outer lists can be of length > 1.
                    new_logprobs = logprobs.slice(req_index, req_index + 1)

            # Transmit partial if chunked prefill & prompt logprobs is enabled
            if new_token_ids or prompt_logprobs_tensors is not None:
                # Add EngineCoreOutput for this Request.
                outputs.append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_logprobs=new_logprobs,
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        stop_reason=request.stop_reason,
                        events=request.take_events()))

            # The step for this request is done; it is no longer in flight.
            self.scheduled_req_ids.remove(request.request_id)
            if not stopped:
                new_running.append(request)

        self.running = new_running
        return EngineCoreOutputs(
            outputs=outputs,
            scheduler_stats=self.make_stats(),
        )

    def _check_stop(self, request: Request) -> bool:
        if (request.num_tokens >= self.max_model_len
                or request.num_output_tokens >= request.max_tokens):
            request.status = RequestStatus.FINISHED_LENGTH_CAPPED
            return True

        sampling_params = request.sampling_params
        last_token_id = request.output_token_ids[-1]
        if (not sampling_params.ignore_eos
                and last_token_id == request.eos_token_id):
            request.status = RequestStatus.FINISHED_STOPPED
            return True

        if last_token_id in (sampling_params.stop_token_ids or ()):
            request.status = RequestStatus.FINISHED_STOPPED
            request.stop_reason = last_token_id
            return True
        return False

    def add_request(self, request: Request) -> None:
        self.waiting.append(request)
        self.requests[request.request_id] = request
612
        self.request_queued(request)
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636

    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: RequestStatus,
    ) -> None:
        """Handles the finish signal from outside the scheduler.

        For example, the API server can abort a request when the client
        disconnects.

        Args:
            request_ids: A single request ID or an iterable of IDs. IDs that
                are not currently tracked (e.g. already finished) are
                ignored.
            finished_status: The terminal status to assign. It must satisfy
                ``RequestStatus.is_finished``.
        """
        assert RequestStatus.is_finished(finished_status)
        if isinstance(request_ids, str):
            request_ids = (request_ids, )
        request_ids = set(request_ids)

        for req_id in request_ids:
            request = self.requests.get(req_id)
            if request is None:
                # Invalid request ID.
                continue

            if request.status == RequestStatus.RUNNING:
                self.running.remove(request)
                # It may or may not be in flight in the executor right now.
                self.scheduled_req_ids.discard(request.request_id)
            else:
                self.waiting.remove(request)
            request.status = finished_status
            self._free_request(request)

    def _free_request(self, request: Request) -> None:
        assert request.is_finished()
        self.kv_cache_manager.free(request)
647
        self.kv_cache_manager.free_block_hashes(request)
648
        self.encoder_cache_manager.free(request)
649
        self._cached_reqs_data.pop(request.request_id, None)
650
651
652
653
654
655
656
657
658
        del self.requests[request.request_id]
        self.finished_req_ids.add(request.request_id)

    def get_num_unfinished_requests(self) -> int:
        return len(self.waiting) + len(self.running)

    def has_unfinished_requests(self) -> bool:
        return self.get_num_unfinished_requests() > 0

659
660
661
662
    def get_num_unscheduled_requests(self) -> int:
        """Number of requests that are not being processed by the executor."""
        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)

663
664
665
    def reset_prefix_cache(self) -> bool:
        return self.kv_cache_manager.reset_prefix_cache()

666
667
668
669
670
671
672
673
674
675
676
677
678
    def request_queued(self, request: Request):
        if not self.log_stats:
            return
        request.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.QUEUED))

    def request_scheduled(self, request: Request, timestamp: float):
        if not self.log_stats:
            return
        request.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
                                      timestamp))

679
680
681
682
683
684
685
    def request_preempted(self, request: Request, timestamp: float):
        if not self.log_stats:
            return
        request.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.PREEMPTED,
                                      timestamp))

686
687
688
    def make_stats(self) -> Optional[SchedulerStats]:
        if not self.log_stats:
            return None
689
690
691
        return SchedulerStats(
            num_running_reqs=len(self.running),
            num_waiting_reqs=len(self.waiting),
692
            gpu_cache_usage=self.kv_cache_manager.usage,
693
            prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
694
        )