llm_engine.py 77.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Antoni Baum's avatar
Antoni Baum committed
4
import time
5
from collections import Counter as collectionsCounter
6
from collections import deque
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
11
                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
12
from typing import Sequence as GenericSequence
13
from typing import Set, Type, Union, cast
14

15
import torch
16
from typing_extensions import TypeVar
17

18
import vllm.envs as envs
19
20
21
from vllm.config import (DecodingConfig, ModelConfig, ObservabilityConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.config.lora import LoRAConfig
22
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
28
29
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
30
from vllm.executor.executor_base import ExecutorBase
31
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
32
from vllm.inputs.parse import split_enc_dec_inputs
33
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.logger import init_logger
35
from vllm.logits_process import get_bad_words_logits_processors
36
from vllm.lora.request import LoRARequest
37
from vllm.model_executor.layers.sampler import SamplerOutput
38
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
39
from vllm.multimodal.cache import processor_only_cache_from_config
40
from vllm.multimodal.processing import EncDecMultiModalProcessor
41
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
42
                          RequestOutputFactory)
43
from vllm.reasoning import ReasoningParser, ReasoningParserManager
44
from vllm.sampling_params import RequestOutputKind, SamplingParams
45
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
46
47
48
                           Sequence, SequenceGroup, SequenceGroupBase,
                           SequenceGroupMetadata, SequenceGroupOutput,
                           SequenceStatus)
49
50
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
51
from vllm.transformers_utils.detokenizer import Detokenizer
52
53
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
54
55
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
56
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
57
from vllm.version import __version__ as VLLM_VERSION
58
from vllm.worker.model_runner_base import InputProcessingError
59
60

logger = init_logger(__name__)
# Passed as `local_interval` to both default stat loggers created in
# LLMEngine.__init__ (presumably seconds, per the _SEC suffix).
_LOCAL_LOGGING_INTERVAL_SEC = 5

# Type variable for validate_output(s): constrained to the two request-output
# types the engine can produce (generation and pooling).
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
# Unconstrained type variable whose default (typing_extensions `default=`)
# is `Any`.
_R = TypeVar("_R", default=Any)
65
66


67
68
69
70
71
@dataclass
class SchedulerOutputState:
    """Cached scheduling results for a single virtual engine.

    The engine keeps one instance per pipeline-parallel stage. Every field
    starts out "empty" (``None`` / ``False``). Per the original note, this
    cache exists for multi-step scheduling.
    """
    # Cached sequence-group metadata; None until populated.
    seq_group_metadata_list: Union[List[SequenceGroupMetadata], None] = None
    # Cached scheduler outputs matching the metadata above; None until set.
    scheduler_outputs: Union[SchedulerOutputs, None] = None
    # Whether async output processing was allowed for the cached step.
    allow_async_output_proc: bool = False
    # Most recent sampler output, if any.
    last_output: Union[SamplerOutput, None] = None
74
75


76
77
78
79
80
81
class OutputData(NamedTuple):
    """One batch of model outputs queued for later processing.

    Instances are created by ``SchedulerContext.append_output`` and kept in
    its ``output_queue``.
    """
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # True when these outputs come from the first step of a multi-step run;
    # always True when multi-step is disabled. The value is not meaningful
    # when `outputs` holds results from several steps.
    is_first_step_output: Union[bool, None]
    # Starts out as an empty list in append_output; presumably indices of
    # entries to skip during processing (confirm with the consumers).
    skip: List[int]


91
class SchedulerContext:
    """Mutable per-virtual-engine state handed to output processing.

    The engine creates one context per pipeline-parallel stage.
    """

    def __init__(self) -> None:
        # FIFO of pending batches; append_output pushes on the right.
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[RequestOutput] = []
        self.seq_group_metadata_list: Union[List[SequenceGroupMetadata],
                                            None] = None
        self.scheduler_outputs: Union[SchedulerOutputs, None] = None

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Queue one batch of outputs with an initially empty ``skip`` list."""
        # Positional order mirrors the OutputData field declaration order.
        batch = OutputData(outputs, seq_group_metadata_list,
                           scheduler_outputs, is_async, is_last_step,
                           is_first_step_output, [])
        self.output_queue.append(batch)
113
114


115
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
116
    """An LLM engine that receives requests and generates texts.
117

Woosuk Kwon's avatar
Woosuk Kwon committed
118
    This is the main class for the vLLM engine. It receives requests
119
120
121
122
123
124
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

125
126
127
    The [`LLM`][vllm.LLM] class wraps this class for offline batched inference
    and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine]
    class wraps this class for online serving.
128

129
    The config arguments are derived from [`EngineArgs`][vllm.EngineArgs].
130
131

    Args:
132
        vllm_config: The configuration for initializing and running vLLM.
133
134
        executor_class: The model executor class for managing distributed
            execution.
135
        log_stats: Whether to log statistics.
136
        usage_context: Specified entry point, used for usage info collection.
137
    """
138

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Enable runtime type checks in ``validate_output(s)`` for a block.

        The previous value of ``DO_VALIDATE_OUTPUT`` is restored on exit.
        The restore happens in a ``finally`` clause, so an exception raised
        inside the ``with`` body cannot leave validation switched on. Nested
        uses also leave the outer block's setting intact.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` typed as ``output_type``.

        The ``isinstance`` check only runs while ``DO_VALIDATE_OUTPUT`` is
        set (for example inside ``enable_output_validation``). Otherwise
        this is a plain cast.

        Raises:
            TypeError: If validation is on and ``output`` is not an instance
                of ``output_type``.
        """
        should_check = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if should_check and not isinstance(output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return cast(_O, output)
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return ``outputs`` typed as a list of ``output_type``.

        When validation is off, the input sequence itself is returned
        without copying. When it is on, a new list is built and the first
        element of the wrong type raises.

        Raises:
            TypeError: If validation is on and any element is not an
                instance of ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # Fast path: no copy, just a typing cast.
            return cast(List[_O], outputs)

        checked: List[_O] = []
        for candidate in outputs:
            if not isinstance(candidate, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(candidate)}")
            checked.append(candidate)
        return checked

188
    tokenizer: Optional[AnyTokenizer]
189

190
191
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the V0 engine from an already-created ``VllmConfig``.

        Args:
            vllm_config: Full engine configuration. Its sub-configs are
                exposed as attributes (``model_config``, ``cache_config``,
                ...).
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Whether to set up stat loggers.
            usage_context: Entry point reported in usage statistics.
            stat_loggers: Custom stat loggers. When None and ``log_stats`` is
                True, a logging logger and a Prometheus logger are created.
            mm_registry: Multimodal registry used by the input preprocessor.
            use_cached_outputs: Stored on the engine and logged at startup.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is set.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Convenience aliases for the sub-configs. Missing decoding and
        # observability configs fall back to default instances.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        # With skip_tokenizer_init there is no tokenizer or detokenizer;
        # get_tokenizer() raises in that case.
        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
            self.detokenizer = None
        else:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)

        # Source of sequence ids (next() in _add_processed_request); also
        # handed to the output processor below.
        self.seq_counter = Counter()
        # Generation-config defaults from the model. They are applied to
        # each request's SamplingParams when its sequence group is created.
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(
            self.model_config,
            self.tokenizer,
            mm_registry,
            mm_processor_cache=processor_only_cache_from_config(
                self.model_config, mm_registry),
        )

        self.model_executor = executor_class(vllm_config=vllm_config)

        # Sizes and allocates the KV cache and writes the block counts back
        # into cache_config. This must run before the schedulers are
        # created (see the NOTE below).
        self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,
                    "kv_cache_memory_bytes":
                    self.cache_config.kv_cache_memory_bytes,
                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        # One cached-output slot and one scheduler context per virtual
        # engine (pipeline-parallel stage).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Async output processing: one callback per virtual engine, each
        # bound to that engine's SchedulerContext. weak_bind presumably
        # keeps only a weak reference to self, so the callbacks do not keep
        # the engine alive (confirm in vllm.utils).
        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        # scheduler_cls may be a class or its fully qualified name.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is enabled only when an OTLP traces endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Initialize reasoning parser if reasoning backend is set.
        # NOTE: self.reasoner is only assigned when this condition holds.
        # The StopChecker below re-checks the same condition before reading
        # it.
        if self.decoding_config.reasoning_backend and \
                self.tokenizer:
            reasoner_class = ReasoningParserManager.get_reasoning_parser(
                self.decoding_config.reasoning_backend)
            self.reasoner: ReasoningParser = reasoner_class(
                self.tokenizer.get_lora_tokenizer())

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    self.reasoner if self.decoding_config.reasoning_backend
                    and self.tokenizer else None,
                ),
            ))

        # Starts empty; presumably filled by parallel-sampling requests
        # (not visible here).
        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

398
399
400
401
402
403
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The executor reports how many GPU and CPU (swap) cache blocks are
        available. A non-None ``cache_config.num_gpu_blocks_override`` takes
        precedence for the GPU count. The final counts are written back to
        ``cache_config`` and then used to allocate the cache on the workers.
        """
        # perf_counter() is monotonic, whereas time.time() follows the wall
        # clock and can jump (e.g. NTP adjustments), producing a wrong or
        # even negative duration.
        start = time.perf_counter()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
        if num_gpu_blocks_override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        # Publish the final counts so later consumers (e.g. the schedulers)
        # see them through cache_config.
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
423

424
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Resolve the executor class for ``engine_config``.

        ``parallel_config.distributed_executor_backend`` is either an
        ``ExecutorBase`` subclass, which is used as-is, or one of the
        strings ``"ray"``, ``"mp"``, ``"uni"`` or ``"external_launcher"``.
        Each backend module is imported only inside its own branch.

        Raises:
            TypeError: If a class is given that does not subclass
                ``ExecutorBase``.
            ValueError: If the backend string is unknown, or ``"mp"`` is
                used together with ``VLLM_USE_RAY_SPMD_WORKER=1``.
        """
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            # Explicit raise rather than `assert`: assert statements are
            # removed under `python -O`, which would silently skip this
            # configuration check.
            if envs.VLLM_USE_RAY_SPMD_WORKER:
                raise ValueError(
                    "multiprocessing distributed executor backend does not "
                    "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingDistributedExecutor
        elif distributed_executor_backend == "uni":
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # executor with external launcher
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("unrecognized distributed_executor_backend: "
                             f"{distributed_executor_backend}")
        return executor_class

462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Construct an engine from ``vllm_config``.

        The executor class is resolved from the config's
        ``distributed_executor_backend``. Stats are logged unless
        ``disable_log_stats`` is True.
        """
        executor_class = cls._get_executor_cls(vllm_config)
        log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

478
479
480
481
482
483
484
485
486
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        When ``envs.VLLM_USE_V1`` is set, the V1 ``LLMEngine`` is built
        instead of this class. Both are constructed via
        ``from_vllm_config``.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)

        # Default to this (V0) class and switch to V1 when it is enabled.
        engine_cls = cls
        if envs.VLLM_USE_V1:
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            engine_cls = V1LLMEngine

        disable_log_stats = engine_args.disable_log_stats
        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=disable_log_stats,
        )
500

501
502
503
504
505
    def __reduce__(self):
        # Refuse to pickle so an LLMEngine can never end up captured in the
        # closure used to initialize Ray worker actors.
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

506
507
508
509
510
511
    def __del__(self):
        # Shut the executor down when the engine is garbage collected.
        # getattr() with a default is needed because __init__ may have
        # failed before model_executor was assigned.
        model_executor = getattr(self, "model_executor", None)
        if not model_executor:
            return
        model_executor.shutdown()

512
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the engine's tokenizer.

        Raises:
            ValueError: If the engine was created with
                ``skip_tokenizer_init`` and therefore has no tokenizer.
        """
        tokenizer = self.tokenizer
        if tokenizer is not None:
            return tokenizer
        raise ValueError("Unable to get tokenizer because "
                         "skip_tokenizer_init is True")
518

519
520
    def _init_tokenizer(self) -> AnyTokenizer:
        """Build the tokenizer described by ``self.model_config``."""
        model_config = self.model_config
        return init_tokenizer_from_configs(model_config=model_config)
521

522
523
    def _verify_args(self) -> None:
        """Cross-check the model, cache and (optional) LoRA configs."""
        parallel_config = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel_config)
        self.cache_config.verify_with_parallel_config(parallel_config)

        lora_config = self.lora_config
        if not lora_config:
            return
        lora_config.verify_with_model_config(self.model_config)
        lora_config.verify_with_scheduler_config(self.scheduler_config)
529

530
531
532
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add already-preprocessed inputs to the engine's request pool.

        Requests with ``params.n > 1`` are delegated to
        ``ParallelSampleSequenceGroup.add_request`` and ``None`` is
        returned. Otherwise a single ``SequenceGroup`` is built and returned.
        It gets an extra encoder sequence when the inputs include encoder
        inputs, and is handed to the scheduler that currently has the fewest
        unfinished sequence groups.

        Raises:
            ValueError: If ``params`` is not a ``SamplingParams``.
        """
        is_sampling = isinstance(params, SamplingParams)
        if is_sampling and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs)

        # Build the decoder sequence (and, if present, the encoder one).
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id()

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request)
        if encoder_inputs is None:
            encoder_seq = None
        else:
            # The encoder sequence shares the decoder sequence's id.
            encoder_seq = Sequence(seq_id, encoder_inputs, block_size,
                                   eos_token_id, lora_request)

        if not is_sampling:
            raise ValueError("SamplingParams must be provided.")
        seq_group = self._create_sequence_group_with_sampling(
            request_id,
            seq,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            encoder_seq=encoder_seq,
            priority=priority)

        # Route to the least-loaded scheduler. min() returns the first
        # minimal element on ties, matching list.index(min(...)).
        target_scheduler = min(
            self.scheduler,
            key=lambda sched: sched.get_num_unfinished_seq_groups())
        target_scheduler.add_seq_group(seq_group)

        return seq_group

594
595
    def stop_remote_worker_execution_loop(self) -> None:
        """Forward the stop request to the model executor."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
596

597
598
599
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request. Must be a `str`.
            prompt: The prompt to the LLM. See
                [PromptType][vllm.inputs.PromptType]
                for more details about the format of each input.
            params: Parameters for sampling.
                [SamplingParams][vllm.SamplingParams] for text generation.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (`time.time()`) is used.
            lora_request: The LoRA request to add.
            tokenization_kwargs: Extra keyword arguments forwarded to the
                input preprocessor.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Raises:
            TypeError: If `request_id` is not a string.
            ValueError: In any of these cases:
                - a LoRA request is given but LoRA is not enabled;
                - a non-zero priority is given without priority scheduling;
                - logits processors are set;
                - `prompt_logprobs` is combined with prompt embeddings.

        Details:
            - Set arrival_time to the current time if it is None.
            - For prompts given as embeddings without `prompt_token_ids`,
              use placeholder token ids (one `0` per embedding row). The
              caller's prompt dict is never modified.
            - Preprocess the prompt, then hand the result to
              `_add_processed_request`. That method creates the
              [Sequence][vllm.sequence.Sequence] and
              [SequenceGroup][vllm.sequence.SequenceGroup] objects and adds
              the group to a scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            ...    str(request_id),
            ...    example_prompt,
            ...    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if not isinstance(request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request_id)}")

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if isinstance(params, SamplingParams) and params.logits_processors:
            raise ValueError(
                "Logits processors are not supported in multi-step decoding")

        if arrival_time is None:
            # Wall-clock time (not a monotonic clock).
            arrival_time = time.time()

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None):
            if not prompt.get("prompt_token_ids", None):
                # Embedding prompts still need one placeholder token id per
                # position. Build a new dict instead of writing into the
                # caller's, so the caller's prompt is left untouched (even
                # when the request is rejected below).
                seq_len = prompt["prompt_embeds"].shape[0]
                prompt = cast(PromptType, {
                    **prompt, "prompt_token_ids": [0] * seq_len
                })
            if params.prompt_logprobs is not None:
                raise ValueError(
                    "prompt_logprobs is not compatible with prompt embeds.")

        processed_inputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )
696
697
698
699
700
701

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a ``SequenceGroup`` for a sampling (generation) request.

        Validates the requested logprob counts against the model's
        ``max_logprobs``, attaches logits processors, takes a defensive copy
        of the sampling params, merges generation-config defaults into the
        copy, and wraps ``seq`` in a new group.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model's ``max_logprobs``.
        """
        limit = self.get_model_config().max_logprobs
        requested = (sampling_params.logprobs, sampling_params.prompt_logprobs)
        if any(count and count > limit for count in requested):
            raise ValueError(f"Cannot request more than "
                             f"{limit} logprobs.")

        params = self._build_logits_processors(sampling_params, lora_request)
        # Defensive copy of the params used by the sampler; clone() does not
        # deep-copy the LogitsProcessor objects attached just above.
        params = params.clone()
        params.update_from_generation_config(self.generation_config_fields,
                                             seq.eos_token_id)

        speculative_config = self.vllm_config.speculative_config
        draft_size = (1 if speculative_config is None else
                      speculative_config.num_speculative_tokens + 1)

        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             sampling_params=params,
                             lora_request=lora_request,
                             trace_headers=trace_headers,
                             encoder_seq=encoder_seq,
                             priority=priority,
                             draft_size=draft_size)

Antoni Baum's avatar
Antoni Baum committed
744
745
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests by ID.

        The abort is forwarded to every scheduler held by the engine. Each
        scheduler also receives the engine's ``seq_id_to_seq_group`` mapping.

        Args:
            request_id: A single request ID, or an iterable of IDs.

        Details:
            - Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][].

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for sched in self.scheduler:
            sched.abort_seq_group(request_id,
                                  seq_id_to_seq_group=self.seq_id_to_seq_group)
762

763
764
765
766
    def get_vllm_config(self) -> VllmConfig:
        """Return the engine's top-level ``VllmConfig``."""
        config = self.vllm_config
        return config

767
768
769
770
    def get_model_config(self) -> ModelConfig:
        """Return the engine's ``ModelConfig``."""
        config = self.model_config
        return config

771
772
773
774
    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's ``ParallelConfig``."""
        config = self.parallel_config
        return config

775
776
777
778
    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's ``DecodingConfig``."""
        config = self.decoding_config
        return config

779
780
781
782
783
784
785
786
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's ``SchedulerConfig``."""
        config = self.scheduler_config
        return config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Gets the LoRA configuration.

        Returns:
            The engine's LoRA configuration. This may be None/falsy when LoRA
            is not enabled. The engine itself treats a falsy ``lora_config``
            as "LoRA disabled", so callers must handle the missing case.
        """
        return self.lora_config

787
    def get_num_unfinished_requests(self) -> int:
        """Return the total number of unfinished sequence groups.

        The count is summed over ``get_num_unfinished_seq_groups()`` of every
        scheduler.
        """
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total
791

792
    def has_unfinished_requests(self) -> bool:
        """Return True as soon as any scheduler reports unfinished sequences."""
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Return True if the scheduler at index ``virtual_engine`` still
        has unfinished sequences."""
        scheduler = self.scheduler[virtual_engine]
        return scheduler.has_unfinished_seqs()
803

804
805
    def reset_mm_cache(self) -> bool:
        """Clear the multi-modal cache held by the input preprocessor.

        Returns:
            Always True.
        """
        preprocessor = self.input_preprocessor
        preprocessor.clear_cache()
        return True
808

809
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset prefix cache for all devices.

        Every scheduler is asked to reset its prefix cache, even if an
        earlier one reports failure.

        Args:
            device: Forwarded unchanged to each scheduler's
                ``reset_prefix_cache``.

        Returns:
            True only if every scheduler reported a successful reset.
        """
        # Build the full list first: a short-circuiting `success and ...`
        # (or all() over a generator) would skip the remaining schedulers
        # after the first failure, leaving their caches un-reset.
        results = [
            scheduler.reset_prefix_cache(device)
            for scheduler in self.scheduler
        ]
        return all(results)

817
818
819
820
821
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and collect the resulting request outputs.

        Takes the oldest entry of ``ctx.output_queue``. The entry is popped
        in the normal case, or only peeked at when ``request_id`` is given.
        Each scheduled sequence group is updated with its output, and the
        ``RequestOutput`` objects created are appended to
        ``ctx.request_outputs``. When ``self.process_request_outputs_callback``
        is set, those outputs may be handed to it and
        ``ctx.request_outputs`` cleared. Nothing is returned.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, then only this request is going to be
                processed. Its index is appended to the queue entry's
                ``skip`` list so a later full pass does not process it again.
        """
        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        # NOTE(review): multiple output sets are asserted away here, so the
        # `has_multiple_outputs` branch further down is only reachable when
        # assertions are disabled (python -O).
        assert not has_multiple_outputs
        outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        # Indices of groups that were already finished before this pass
        # (they get no new output) and of groups that finish during it.
        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            # In the async path, computed-token counts were already advanced
            # by _advance_to_next_step, so only do it here when not async.
            if not is_async:
                seq_group.update_num_computed_tokens(
                    seq_group_meta.token_chunk_size or 0)

            # Accumulate model forward/execute timings from every
            # SamplerOutput into this group's metrics. Once a total exists, a
            # None timing on the output is added as 0.
            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            self.output_processor.process_prompt_logprob(seq_group, output)
            if seq_group_meta.do_sample:
                self.output_processor.process_outputs(seq_group, output,
                                                      is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # Create the outputs for groups that are still running (finished and
        # skipped groups were handled above or in an earlier call).
        for i in indices:
            if i in skip or i in finished_before or i in finished_now:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # Create outputs only after processing the scheduler's results

        # Ignored groups also get an output, except unfinished DELTA-mode
        # requests, which are skipped.
        for seq_group in scheduler_outputs.ignored_seq_groups:
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _advance_to_next_step(
            self, output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Append freshly sampled tokens to their sequences.

        The output processor normally does this. It is done here instead so
        the worker can start the next forward pass asynchronously.

        Finished groups are skipped. For every other group:
          - its computed-token count is advanced by the metadata's chunk
            size (``None`` counts as 0);
          - if it sampled, its single sample is appended to its single
            sequence.
        """
        triples = zip(seq_group_metadata_list, output, scheduled_seq_groups)
        for meta, group_output, scheduled in triples:
            group = scheduled.seq_group
            if group.is_finished():
                continue

            chunk = meta.token_chunk_size
            group.update_num_computed_tokens(0 if chunk is None else chunk)

            if not meta.do_sample:
                continue

            samples = group_output.samples
            assert len(samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1)")
            only_sample = samples[0]

            assert len(group.seqs) == 1
            group.seqs[0].append_token_id(only_sample.output_token,
                                          only_sample.logprobs,
                                          only_sample.output_embed)
1035

1036
    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        <figure markdown="span">
        ![Overview of the step function](https://i.imgur.com/sv2HssD.png)
        <figcaption>Overview of the step function</figcaption>
        </figure>

        Details:
        - Step 1: Schedules the sequences to be executed in the next
            iteration and the token blocks to be swapped in/out/copy.

            - Depending on the scheduling policy,
                sequences may be `preempted/reordered`.
            - A Sequence Group (SG) refers to a group of sequences
                that are generated from the same prompt.

        - Step 2: Calls the distributed executor to execute the model.
        - Step 3: Processes the model output. This mainly includes:

            - Decodes the relevant outputs.
            - Updates the scheduled sequence groups with model outputs
                based on its `sampling parameters` (`use_beam_search` or not).
            - Frees the finished sequence groups.

        - Finally, it creates and returns the newly generated results.

        Returns:
            The request outputs collected in the scheduler context
            (``ctx.request_outputs``) for virtual engine 0.

        Raises:
            NotImplementedError: If pipeline parallelism is configured
                (``pipeline_parallel_size > 1``).
            InputProcessingError: Re-raised from the executor after the
                offending request has been aborted and the rest of the
                schedule cached for the next step.

        Example:
        ```
        # Please see the example/ folder for more detailed examples.

        # initialize engine and request arguments
        engine = LLMEngine.from_engine_args(engine_args)
        example_inputs = [(0, "What is LLM?",
        SamplingParams(temperature=0.0))]

        # Start the engine with an event loop
        while True:
            if example_inputs:
                req_id, prompt, sampling_params = example_inputs.pop(0)
                engine.add_request(str(req_id),prompt,sampling_params)

            # continue the request processing
            request_outputs = engine.step()
            for request_output in request_outputs:
                if request_output.finished:
                    # return or show the request output
                    ...

            if not (engine.has_unfinished_requests() or example_inputs):
                break
        ```
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # For llm_engine, there is no pipeline parallel support, so the engine
        # used is always 0.
        virtual_engine = 0

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

        else:
            # Reusing a cached schedule: there are no newly finished request
            # IDs to report to the executor.
            finished_requests_ids = list()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise

        else:
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            outputs = []

        if not self._has_remaining_steps(seq_group_metadata_list):
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            # Add results to the output_queue
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            # Async mode: append the sampled tokens now so the next forward
            # pass can start; full output processing happens later.
            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")

                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Check if need to run the usual non-async path
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
            # Multi-step case
            # NOTE(review): _has_remaining_steps in this file returns False
            # unconditionally, so this branch appears unreachable here.
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise time out, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1244

1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Aborts a single request, and caches the scheduler outputs minus that
        request. This allows the next step to continue processing the remaining
        requests without having to re-run the scheduler.

        Args:
            request_id: ID of the request whose input could not be processed.
            virtual_engine: Index of the virtual engine whose cached
                scheduler outputs are updated.
            seq_group_metadata_list: Metadata of the current batch. The
                aborted request's entry is removed in place.
            scheduler_outputs: Scheduler outputs of the current batch. The
                aborted request's scheduled group is removed in place.
            allow_async_output_proc: Stored alongside the cached schedule.
        """
        # Abort the request and remove its sequence group from the current
        # schedule (first match only, mutating the lists in place).
        self.abort_request(request_id)
        for i, metadata in enumerate(seq_group_metadata_list):
            if metadata.request_id == request_id:
                del seq_group_metadata_list[i]
                break
        scheduled = scheduler_outputs.scheduled_seq_groups
        for i, group in enumerate(scheduled):
            if group.seq_group.request_id == request_id:
                del scheduled[i]
                break

        if seq_group_metadata_list:
            # Other sequence groups are left in the schedule: cache them and
            # flag the engine to reuse the schedule.
            self._skip_scheduling_next_step = True
            # Reuse multi-step caching logic
            self._cache_scheduler_outputs_for_multi_step(
                virtual_engine=virtual_engine,
                scheduler_outputs=scheduler_outputs,
                seq_group_metadata_list=seq_group_metadata_list,
                allow_async_output_proc=allow_async_output_proc)
        else:
            # Nothing left to retry. The flag must be cleared explicitly:
            # it is only reset after a successful execute_model. If a retried
            # schedule fails on its last request, a True left over from the
            # previous abort would make step() skip the scheduler forever and
            # replay an empty cached batch.
            self._skip_scheduling_next_step = False

1277
1278
1279
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Report whether the current batch still has steps left to run.

        This implementation always answers False. The argument is accepted
        for interface compatibility and otherwise ignored.
        """
        del seq_group_metadata_list  # unused
        return False
1281
1282
1283
1284

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store a schedule in the given virtual engine's cached scheduler
        outputs so a later ``step()`` can reuse it. Any cached
        ``last_output`` is reset to None."""
        cache = self.cached_scheduler_outputs[virtual_engine]
        cache.last_output = None
        cache.allow_async_output_proc = allow_async_output_proc
        cache.scheduler_outputs = scheduler_outputs
        cache.seq_group_metadata_list = seq_group_metadata_list
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Cache the final ``SamplerOutput`` of a step under pipeline
        parallelism.

        Nothing happens unless ``pipeline_parallel_size > 1`` and ``output``
        is non-empty with a non-None first element. The cached output must
        carry ``sampled_token_ids_cpu``, with ``sampled_token_ids`` and
        ``sampled_token_probs`` both unset (asserted).
        """
        if self.parallel_config.pipeline_parallel_size <= 1:
            return
        if not output or output[0] is None:
            return
        final = output[-1]
        assert final is not None
        assert final.sampled_token_ids_cpu is not None
        assert final.sampled_token_ids is None
        assert final.sampled_token_probs is None
        self.cached_scheduler_outputs[virtual_engine].last_output = final

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        """Return sampled token ids cached from a previous step.

        This implementation never provides any and always returns None.
        """
        del virtual_engine  # unused
        return None

1311
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If a logger is already registered under that name.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If no logger is registered under that name.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        registry = self.stat_loggers
        if logger_name not in registry:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        registry.pop(logger_name)

1329
1330
1331
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Compute stats once and pass them to every registered stat logger.

        Does nothing when stat logging is disabled (``self.log_stats`` is
        falsy). All arguments are forwarded unchanged to ``_get_stats``.
        """
        if not self.log_stats:
            return
        stats = self._get_stats(scheduler_outputs, model_output,
                                finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(stats)
1340

1341
1342
1343
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch. When None, only system-level state
                (queue sizes, KV-cache usage, prefix-cache hit rates, LoRA
                adapters) is reported; iteration counters stay at 0 and the
                per-request lists stay empty.
            model_output: Optional. Not read by this method; kept for
                interface compatibility with existing callers.
            finished_before: Optional, indices (into
                ``scheduler_outputs.scheduled_seq_groups``) of sequence groups
                that were finished before. These are ignored, and each one is
                subtracted from the batched-token count to avoid double
                counting with the async output processor.
            skip: Optional, indices of sequence groups that were preempted.
                These are ignored.

        Returns:
            A ``Stats`` snapshot timestamped with ``time.time()``.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Exchange the usage and cache hit stats between gpu and cpu when
        # running on cpu because the cpu_worker.py intentionally reports the
        # number of cpu blocks as gpu blocks in favor of cache management.
        # (The raw block totals are not reported, so they need no swap.)
        if self.device_config.device_type == "cpu":
            gpu_cache_usage_sys, cpu_cache_usage_sys = (
                cpu_cache_usage_sys,
                gpu_cache_usage_sys,
            )
            gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = (
                cpu_prefix_cache_hit_rate,
                gpu_prefix_cache_hit_rate,
            )

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        inter_token_latencies_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            # Build the index sets once so the per-group membership checks
            # below are O(1) rather than a list scan for every group.
            finished_before_idxs = set(finished_before or ())
            skip_idxs = set(skip or ())

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_before_idxs:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skip_idxs:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # ITLs
                    latency = seq_group.get_last_token_latency()
                    inter_token_latencies_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            inter_token_latencies_iter=inter_token_latencies_iter,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            # max_lora_stat is already a str.
            max_lora=max_lora_stat,
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
1601

1602
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load the LoRA adapter described by ``lora_request``.

        Delegates to the model executor and returns its boolean result.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1604
1605

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with id ``lora_id``.

        Delegates to the model executor and returns its boolean result.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1607

1608
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model executor's
        ``list_loras``, unchanged."""
        executor = self.model_executor
        return executor.list_loras()
1610

1611
1612
1613
    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id ``lora_id``.

        Delegates to the model executor and returns its boolean result.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1614
1615
1616
1617
1618
1619
    def start_profile(self) -> None:
        """Ask the model executor to start profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Ask the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

1620
1621
1622
1623
1624
    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at the given ``level``.

        Raises:
            AssertionError: if ``enable_sleep_mode`` is off in the model
                config (only checked when assertions are enabled).
        """
        sleep_enabled = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_enabled, "Sleep mode is not enabled in the model config"
        self.model_executor.sleep(level=level)

1625
    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the model executor up, forwarding ``tags`` unchanged.

        Raises:
            AssertionError: if ``enable_sleep_mode`` is off in the model
                config (only checked when assertions are enabled).
        """
        sleep_enabled = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_enabled, "Sleep mode is not enabled in the model config"
        self.model_executor.wake_up(tags)
1629

1630
1631
1632
    def is_sleeping(self) -> bool:
        """Return the model executor's ``is_sleeping`` attribute as-is."""
        executor = self.model_executor
        return executor.is_sleeping

1633
    def check_health(self) -> None:
        """Run the model executor's health check.

        Any exception the executor raises propagates to the caller.
        """
        executor = self.model_executor
        executor.check_health()
1635
1636
1637
1638

    def is_tracing_enabled(self) -> bool:
        """Return True when a tracer has been configured."""
        tracer = self.tracer
        return tracer is not None

1639
1640
1641
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Emit a trace span for every finished group in the scheduled batch.

        No-op when no tracer is configured.

        Args:
            scheduler_outputs: The scheduled batch; each group in
                ``scheduled_seq_groups`` that reports ``is_finished()`` gets a
                span via ``create_trace_span``.
            finished_before: Optional indices into
                ``scheduler_outputs.scheduled_seq_groups`` to skip, to avoid
                double tracing when using the async output processor.
        """
        if self.tracer is None:
            return

        # Build the lookup set once instead of scanning the list per group.
        already_handled = set(finished_before or ())

        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if idx in already_handled:
                continue

            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record an ``llm_request`` span describing ``seq_group``.

        Does nothing when tracing is disabled or the group has no sampling
        params. The span is a SERVER span that starts at the request's arrival
        time (in nanoseconds) and uses the context extracted from the
        request's trace headers. Request parameters and token usage are
        always recorded; latency attributes only when their value is not None.
        """
        tracer = self.tracer
        if tracer is None or seq_group.sampling_params is None:
            return
        params = seq_group.sampling_params
        metrics = seq_group.metrics

        start_ns = int(metrics.arrival_time * 1e9)
        parent_ctx = extract_trace_context(seq_group.trace_headers)

        with tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=parent_ctx,
                start_time=start_ns) as seq_span:
            # Cancelled/aborted requests may lack a first-token or finish
            # timestamp, so the derived latencies can be None.
            arrival = metrics.arrival_time
            ttft = (None if metrics.first_token_time is None else
                    metrics.first_token_time - arrival)
            e2e_time = (None if metrics.finished_time is None else
                        metrics.finished_time - arrival)

            always_recorded = (
                (SpanAttributes.GEN_AI_RESPONSE_MODEL,
                 self.model_config.model),
                (SpanAttributes.GEN_AI_REQUEST_ID, seq_group.request_id),
                (SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                 params.temperature),
                (SpanAttributes.GEN_AI_REQUEST_TOP_P, params.top_p),
                (SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, params.max_tokens),
                (SpanAttributes.GEN_AI_REQUEST_N, params.n),
                (SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                 seq_group.num_seqs()),
                (SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                 len(seq_group.prompt_token_ids)),
                (SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                 sum(seq.get_output_len()
                     for seq in seq_group.get_finished_seqs())),
            )
            for attr_name, attr_value in always_recorded:
                seq_span.set_attribute(attr_name, attr_value)

            # model_forward_time is divided by 1000 before recording
            # (presumably ms -> s to match the other latencies; confirm).
            forward_time = metrics.model_forward_time
            recorded_if_known = (
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                 metrics.time_in_queue),
                (SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft),
                (SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time),
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                 metrics.scheduler_time),
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                 None if forward_time is None else forward_time / 1000.0),
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                 metrics.model_execute_time),
            )
            for attr_name, attr_value in recorded_if_known:
                if attr_value is not None:
                    seq_span.set_attribute(attr_name, attr_value)
1722

1723
    def _validate_model_inputs(self, inputs: ProcessorInputs):
        """Validate the encoder prompt (when present), then the decoder one.

        Raises:
            ValueError: propagated from ``_validate_model_input`` when a
                prompt fails validation.
        """
        encoder_part, decoder_part = split_enc_dec_inputs(inputs)

        # Only encoder-decoder inputs carry a separate encoder prompt; it is
        # checked first so its errors surface before decoder errors.
        if encoder_part is not None:
            self._validate_model_input(encoder_part, prompt_type="encoder")

        self._validate_model_input(decoder_part, prompt_type="decoder")
1730

1731
1732
1733
1734
1735
1736
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate a single (encoder or decoder) prompt.

        Checks, in order: the prompt is non-empty (unless exempt), every
        token id is within the tokenizer's vocabulary (when a tokenizer is
        available), and the prompt fits in ``max_model_len``.

        Raises:
            ValueError: if any of the checks above fails.
        """
        cfg = self.model_config
        tok = self.tokenizer
        token_ids = prompt_inputs.get("prompt_token_ids", [])

        # Empty prompts are allowed for multimodal encoder inputs (Mllama may
        # have empty encoder inputs for text-only data) and for embeds inputs.
        if not token_ids and not (
            (prompt_type == "encoder" and cfg.is_multimodal_model)
                or prompt_inputs["type"] == "embeds"):
            raise ValueError(f"The {prompt_type} prompt cannot be empty")

        if tok is not None:
            largest_id = max(token_ids, default=0)
            if largest_id > tok.max_token_id:
                raise ValueError(
                    f"Token id {largest_id} is out of vocabulary")

        limit = cfg.max_model_len
        if len(token_ids) <= limit:
            return

        if prompt_type == "encoder" and cfg.is_multimodal_model:
            processor = self.input_preprocessor.mm_registry.create_processor(
                cfg,
                tokenizer=tok or object(),  # Dummy if no tokenizer
            )
            assert isinstance(processor, EncDecMultiModalProcessor)
            if processor.pad_dummy_encoder_prompt:
                # Skip encoder length check for Whisper
                return

        if cfg.is_multimodal_model:
            hint = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens plus multimodal tokens. For image "
                "inputs, the number of image tokens depends on the number "
                "of images, and possibly their aspect ratios as well.")
        else:
            hint = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
        raise ValueError(
            f"The {prompt_type} prompt (length {len(token_ids)}) is "
            f"longer than the maximum model length of {limit}. "
            f"{hint}")
1787
1788
1789
1790

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Construct logits processors from the ``logit_bias``,
        ``allowed_token_ids`` and ``bad_words`` fields of sampling_params.

        ``logit_bias`` and ``allowed_token_ids`` are consumed: they are reset
        to None once converted, so they are not passed down to the model.
        ``bad_words`` is left in place. The constructed processors are
        appended to ``sampling_params.logits_processors`` (which is set to
        the new list when it was None).

        Args:
            sampling_params: Mutated in place and returned.
            lora_request: Not read by this method.

        Returns:
            The same, modified, ``sampling_params`` object.
        """
        logits_processors = []

        if sampling_params.logit_bias or sampling_params.allowed_token_ids:
            tokenizer = self.get_tokenizer()

            processors = get_openai_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        # Truthiness (instead of len(...) > 0) also tolerates bad_words=None,
        # which would otherwise raise TypeError.
        if sampling_params.bad_words:
            tokenizer = self.get_tokenizer()
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

        if logits_processors:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = logits_processors
            else:
                sampling_params.logits_processors.extend(logits_processors)

        return sampling_params
1824

1825
1826
1827
1828
1829
1830
1831
1832
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC to the model executor.

        All arguments are passed through positionally and the executor's
        list of results is returned unchanged (presumably one entry per
        worker; see the executor implementation).
        """
        executor = self.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)

1833

1834
1835
1836
# When VLLM_USE_V1 is explicitly set and truthy, rebind the module-level
# ``LLMEngine`` name to the V1 implementation, so code that imports
# ``LLMEngine`` from this module after it is loaded gets the V1 engine.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    LLMEngine = V1LLMEngine  # type: ignore