llm_engine.py 77.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Antoni Baum's avatar
Antoni Baum committed
4
import time
5
from collections import Counter as collectionsCounter
6
from collections import deque
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
11
                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
12
from typing import Sequence as GenericSequence
13
from typing import Set, Type, Union, cast
14

15
import torch
16
from typing_extensions import TypeVar
17

18
import vllm.envs as envs
19
20
21
from vllm.config import (DecodingConfig, ModelConfig, ObservabilityConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.config.lora import LoRAConfig
22
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
28
29
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
30
from vllm.executor.executor_base import ExecutorBase
31
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
32
from vllm.inputs.parse import split_enc_dec_inputs
33
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.logger import init_logger
35
from vllm.logits_process import get_bad_words_logits_processors
36
from vllm.lora.request import LoRARequest
37
from vllm.model_executor.layers.sampler import SamplerOutput
38
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
39
from vllm.multimodal.cache import processor_only_cache_from_config
40
from vllm.multimodal.processing import EncDecMultiModalProcessor
41
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
42
                          RequestOutputFactory)
43
from vllm.sampling_params import RequestOutputKind, SamplingParams
44
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
45
46
47
                           Sequence, SequenceGroup, SequenceGroupBase,
                           SequenceGroupMetadata, SequenceGroupOutput,
                           SequenceStatus)
48
49
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
50
from vllm.transformers_utils.detokenizer import Detokenizer
51
from vllm.transformers_utils.tokenizer import AnyTokenizer
52
from vllm.transformers_utils.tokenizer_group import (
53
    TokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
54
55
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
56
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
57
from vllm.version import __version__ as VLLM_VERSION
58
from vllm.worker.model_runner_base import InputProcessingError
59
60

# Logger shared by everything in this module.
logger = init_logger(__name__)

# Value passed as `local_interval` to the default stat loggers created in
# `LLMEngine.__init__`.
_LOCAL_LOGGING_INTERVAL_SEC = 5

# Request-output type accepted by `LLMEngine.validate_output(s)`.
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
# Generic result type variable; falls back to `Any` when unspecified.
_R = TypeVar("_R", default=Any)
65
66


67
68
69
70
71
@dataclass
class SchedulerOutputState:
    """Cached scheduler outputs belonging to one virtual engine.

    Used for multi-step scheduling. `LLMEngine.__init__` creates one
    instance per pipeline-parallel virtual engine; every field starts out
    empty (None / False).
    """

    # Metadata of the cached scheduled sequence groups, if any.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scheduler outputs that accompany `seq_group_metadata_list`, if any.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Presumably whether async output processing may be used for the cached
    # step -- confirm against the scheduler code that sets it.
    allow_async_output_proc: bool = False
    # Last sampler output stored for this virtual engine, if any.
    last_output: Optional[SamplerOutput] = None
74
75


76
77
78
79
80
81
class OutputData(NamedTuple):
    """One queued entry of `SchedulerContext.output_queue`.

    Bundles the sampler outputs of a model step with the scheduling
    information that produced them, so they can be processed later.
    """

    # Sampler outputs of the step(s).
    outputs: List[SamplerOutput]
    # Metadata of the sequence groups that were scheduled.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Scheduler outputs for the same step.
    scheduler_outputs: SchedulerOutputs
    # Whether these outputs are processed asynchronously.
    is_async: bool
    # Whether this is the last step.
    is_last_step: bool
    # Indicates if this output is from the first step of the
    # multi-step. When multi-step is disabled, this is always
    # set to True.
    # is_first_step_output is invalid when `outputs` has
    # outputs from multiple steps.
    is_first_step_output: Optional[bool]
    # Starts empty (see `SchedulerContext.append_output`); presumably holds
    # indices to skip during processing -- confirm with the processor.
    skip: List[int]


91
class SchedulerContext:
    """Mutable scheduling/output state kept for one virtual engine.

    `LLMEngine.__init__` creates one instance per pipeline-parallel virtual
    engine and, with async output processing, binds it as the `ctx`
    argument of the output-processing callback.
    """

    def __init__(self) -> None:
        # Outputs waiting to be processed, oldest first.
        self.output_queue: Deque[OutputData] = deque()
        # Request outputs accumulated during processing.
        self.request_outputs: List[RequestOutput] = []
        # Most recent scheduling results, if any.
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Queue one step's outputs as an `OutputData` with an empty
        `skip` list (a fresh list per entry)."""
        entry = OutputData(
            outputs=outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            scheduler_outputs=scheduler_outputs,
            is_async=is_async,
            is_last_step=is_last_step,
            is_first_step_output=is_first_step_output,
            skip=[],
        )
        self.output_queue.append(entry)
113
114


115
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
116
    """An LLM engine that receives requests and generates texts.
117

Woosuk Kwon's avatar
Woosuk Kwon committed
118
    This is the main class for the vLLM engine. It receives requests
119
120
121
122
123
124
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

125
126
127
    The [`LLM`][vllm.LLM] class wraps this class for offline batched inference
    and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine]
    class wraps this class for online serving.
128

129
    The config arguments are derived from [`EngineArgs`][vllm.EngineArgs].
130
131

    Args:
132
        vllm_config: The configuration for initializing and running vLLM.
133
134
        executor_class: The model executor class for managing distributed
            execution.
135
        log_stats: Whether to log statistics.
136
        usage_context: Specified entry point, used for usage info collection.
137
    """
138

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return `output`, typed as `output_type`.

        The isinstance check only runs when `DO_VALIDATE_OUTPUT` is set
        (or under static type checking). Otherwise `output` is passed
        through unchanged.

        Raises:
            TypeError: If validation is active and `output` is not an
                instance of `output_type`.
        """
        should_check = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if should_check and not isinstance(output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return cast(_O, output)
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return `outputs`, typed as a list of `output_type`.

        With validation active (`DO_VALIDATE_OUTPUT` or static type
        checking), a new list is built after checking each element.
        Otherwise the caller's sequence object itself is returned.

        Raises:
            TypeError: If validation is active and any element is not an
                instance of `output_type`.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # Validation disabled: hand back the original object untouched.
            return cast(List[_O], outputs)

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)
        return checked

    # Tokenizer group; set to None in `__init__` when
    # `model_config.skip_tokenizer_init` is True.
    tokenizer: Optional[TokenizerGroup]
189

190
191
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Set up a V0 engine: tokenizer, executor, KV cache and schedulers.

        Args:
            vllm_config: Aggregate config; its sub-configs are mirrored
                onto the engine as attributes.
            executor_class: Executor type, instantiated with `vllm_config`.
            log_stats: Whether to set up stat loggers.
            usage_context: Entry point reported with usage statistics.
            stat_loggers: Loggers used instead of the default logging and
                Prometheus loggers. Only consulted when `log_stats` is True.
            mm_registry: Multimodal registry handed to the input
                preprocessor.
            use_cached_outputs: Stored as `self.use_cached_outputs`.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is set.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # --- Configuration ------------------------------------------------
        cfg = vllm_config
        self.vllm_config = cfg
        self.model_config = cfg.model_config
        self.cache_config = cfg.cache_config
        self.lora_config = cfg.lora_config
        self.parallel_config = cfg.parallel_config
        self.scheduler_config = cfg.scheduler_config
        self.device_config = cfg.device_config
        self.speculative_config = cfg.speculative_config  # noqa
        self.load_config = cfg.load_config
        # Fall back to default-constructed configs when these are unset.
        self.decoding_config = cfg.decoding_config or DecodingConfig()
        self.observability_config = (cfg.observability_config
                                     or ObservabilityConfig())

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            cfg,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        # --- Tokenization -------------------------------------------------
        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tok_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tok_group = None

        # The closure captures only the local `tok_group`, never `self`, to
        # avoid engine GC issues.
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tok_group, ("tokenizer_group cannot be None, "
                               "make sure skip_tokenizer_init is False")
            return tok_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        mm_cache = processor_only_cache_from_config(self.model_config,
                                                    mm_registry)
        self.input_preprocessor = InputPreprocessor(
            self.model_config,
            self.tokenizer,
            mm_registry,
            mm_processor_cache=mm_cache,
        )

        # --- Executor and KV cache ----------------------------------------
        self.model_executor = executor_class(vllm_config=cfg)
        self._initialize_kv_caches()

        # --- Usage statistics (only when enabled) -------------------------
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            arch_name = get_architecture_class_name(self.model_config)
            usage_kvs = {
                # Common configuration
                "dtype": str(self.model_config.dtype),
                "tensor_parallel_size":
                self.parallel_config.tensor_parallel_size,
                "block_size": self.cache_config.block_size,
                "gpu_memory_utilization":
                self.cache_config.gpu_memory_utilization,
                "kv_cache_memory_bytes":
                self.cache_config.kv_cache_memory_bytes,
                # Quantization
                "quantization": self.model_config.quantization,
                "kv_cache_dtype": str(self.cache_config.cache_dtype),
                # Feature flags
                "enable_lora": bool(self.lora_config),
                "enable_prefix_caching":
                self.cache_config.enable_prefix_caching,
                "enforce_eager": self.model_config.enforce_eager,
                "disable_custom_all_reduce":
                self.parallel_config.disable_custom_all_reduce,
            }
            usage_message.report_usage(arch_name,
                                       usage_context,
                                       extra_kvs=usage_kvs)

        # --- Per-virtual-engine state -------------------------------------
        num_virtual_engines = self.parallel_config.pipeline_parallel_size
        self.cached_scheduler_outputs = [
            SchedulerOutputState() for _ in range(num_virtual_engines)
        ]
        self.scheduler_contexts = [
            SchedulerContext() for _ in range(num_virtual_engines)
        ]

        use_async = self.model_config.use_async_output_proc
        if use_async:
            # One callback per virtual engine, each bound to its own context.
            # weak_bind is used instead of a plain bound method (presumably
            # to avoid a strong reference to the engine -- confirm).
            process_outputs = weak_bind(self._process_model_outputs)
            self.async_callbacks = [
                partial(process_outputs, ctx=ctx)
                for ctx in self.scheduler_contexts
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # --- Schedulers ---------------------------------------------------
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        scheduler_cls = self.vllm_config.scheduler_config.scheduler_cls
        if isinstance(scheduler_cls, str):
            scheduler_cls = resolve_obj_by_qualname(scheduler_cls)
        self.scheduler = [
            scheduler_cls(
                self.scheduler_config, self.cache_config, self.lora_config,
                num_virtual_engines,
                self.async_callbacks[v_id] if use_async else None)
            for v_id in range(num_virtual_engines)
        ]

        # --- Metric logging -----------------------------------------------
        if self.log_stats:
            if stat_loggers is None:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=cfg),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=cfg),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)
            else:
                self.stat_loggers = stat_loggers

        # --- Tracing (only when an OTLP endpoint is configured) -----------
        self.tracer = None
        traces_endpoint = self.observability_config.otlp_traces_endpoint
        if traces_endpoint:
            self.tracer = init_tracer("vllm.llm_engine", traces_endpoint)

        # --- Output processing --------------------------------------------
        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        stop_checker = StopChecker(self.scheduler_config.max_model_len,
                                   get_tokenizer_for_seq)
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=stop_checker,
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

397
398
399
400
401
402
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        Asks the executor how many GPU and swap-CPU blocks are available.
        Applies `cache_config.num_gpu_blocks_override` when it is set, and
        records the final counts on `cache_config`. Then has the executor
        allocate the cache and logs how long the whole step took.
        """
        # perf_counter is monotonic; wall-clock time.time() can jump (NTP
        # sync, manual clock changes) and report a wrong or negative duration.
        start = time.perf_counter()

        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        override = self.cache_config.num_gpu_blocks_override
        if override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks, override)
            num_gpu_blocks = override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
422

423
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Map `parallel_config.distributed_executor_backend` to a class.

        The backend may be an `ExecutorBase` subclass (returned as-is) or
        one of the strings "ray", "mp", "uni" or "external_launcher". For
        the string backends, the executor module is imported lazily.

        Raises:
            TypeError: If a class that does not subclass `ExecutorBase` is
                given.
            ValueError: If the backend is not recognized.
        """
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        backend = engine_config.parallel_config.distributed_executor_backend

        if isinstance(backend, type):
            if not issubclass(backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {backend}.")
            return backend

        if backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            return RayDistributedExecutor

        if backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            return MultiprocessingDistributedExecutor

        if backend == "uni":
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            return UniProcExecutor

        if backend == "external_launcher":
            # executor with external launcher
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            return ExecutorWithExternalLauncher

        raise ValueError("unrecognized distributed_executor_backend: "
                         f"{backend}")

461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Build an engine from a complete `VllmConfig`.

        The executor class comes from `_get_executor_cls`. Stats logging is
        on unless `disable_log_stats` is True.
        """
        executor_class = cls._get_executor_cls(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

477
478
479
480
481
482
483
484
485
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        When `envs.VLLM_USE_V1` is set, the V1 `LLMEngine` is constructed
        instead, so the result is not necessarily an instance of `cls`.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            target_cls = V1LLMEngine
        else:
            target_cls = cls

        return target_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
        )
499

500
501
502
503
504
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

505
506
507
508
509
510
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

511
512
    def get_tokenizer_group(self) -> TokenizerGroup:
        if self.tokenizer is None:
513
514
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
515

516
        return self.tokenizer
517

518
    def get_tokenizer(
519
520
521
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
522
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
523

524
    def _init_tokenizer(self) -> TokenizerGroup:
        """Build the tokenizer group from the model, scheduler and LoRA
        configs stored on the engine."""
        configs = dict(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            lora_config=self.lora_config,
        )
        return init_tokenizer_from_configs(**configs)
529

530
531
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
532
        self.cache_config.verify_with_parallel_config(self.parallel_config)
533
534
535
536
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
537

538
539
540
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Requests with `n > 1` are delegated to
        `ParallelSampleSequenceGroup.add_request`, and None is returned.
        Otherwise a decoder `Sequence` is built, plus an encoder `Sequence`
        when the inputs have an encoder part. Both share one sequence id.
        They are wrapped in a `SequenceGroup`, which is added to the
        scheduler with the fewest unfinished sequence groups.

        Returns:
            The created sequence group, or None on the `n > 1` path.

        Raises:
            ValueError: If `params` is not a `SamplingParams`.
        """
        is_sampling = isinstance(params, SamplingParams)
        if is_sampling and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request)
        if encoder_inputs is None:
            encoder_seq = None
        else:
            encoder_seq = Sequence(seq_id, encoder_inputs, block_size,
                                   eos_token_id, lora_request)

        # Create a SequenceGroup based on SamplingParams
        if not is_sampling:
            raise ValueError("SamplingParams must be provided.")
        seq_group = self._create_sequence_group_with_sampling(
            request_id,
            seq,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            encoder_seq=encoder_seq,
            priority=priority)

        # Route to the least-loaded scheduler; ties go to the lowest index.
        loads = [
            sched.get_num_unfinished_seq_groups() for sched in self.scheduler
        ]
        target = self.scheduler[loads.index(min(loads))]
        target.add_seq_group(seq_group)

        return seq_group

602
603
    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()
604

605
606
607
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See
                [PromptType][vllm.inputs.PromptType]
                for more details about the format of each input.
            params: Parameters for sampling.
                [SamplingParams][vllm.SamplingParams] for text generation.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (`time.time()`) is used.
            lora_request: The LoRA request to add.
            tokenization_kwargs: Extra keyword arguments forwarded to the
                input preprocessor.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Raises:
            TypeError: If `request_id` is not a string.
            ValueError: If a LoRA request is given while LoRA is disabled,
                if a non-zero priority is given without priority
                scheduling, or if `params` carries logits processors.

        Details:
            - Set arrival_time to the current time if it is None.
            - For a dict prompt that carries `prompt_embeds` but no
              `prompt_token_ids`, use zero placeholder token ids (one per
              embedding row). This is done on a copy; the caller's dict is
              left unchanged.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create [Sequence][vllm.sequence.Sequence] object(s) and a
              [SequenceGroup][vllm.sequence.SequenceGroup] from them
              (requests with `n > 1` go through a parallel-sampling group).
            - Add the [SequenceGroup][vllm.sequence.SequenceGroup] object to the
              scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            ...     str(request_id),
            ...     example_prompt,
            ...     sampling_params)
            >>> # continue the request processing
            >>> ...
        """
        if not isinstance(request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request_id)}")

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # NOTE(review): this check is unconditional, although its message
        # still refers to multi-step decoding.
        if isinstance(params, SamplingParams) and params.logits_processors:
            raise ValueError(
                "Logits processors are not supported in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # Embedding-only prompt: add zero placeholder token ids, one per
            # embedding row. Build a shallow copy instead of writing into
            # the caller's dict, so a reused prompt dict does not keep stale
            # token ids.
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt = {**prompt, "prompt_token_ids": [0] * seq_len}

        processed_inputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )
702
703
704
705
706
707

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams.

        Args:
            request_id: The unique ID of the request.
            seq: The sequence placed in the new group.
            sampling_params: The sampling parameters of the request. They are
                passed through ``_build_logits_processors``, cloned, and then
                updated from the model's generation config before being
                stored on the group.
            arrival_time: The arrival time of the request.
            lora_request: The LoRA request to apply, if any.
            trace_headers: Optional trace headers stored on the group.
            encoder_seq: Optional encoder sequence stored on the group.
            priority: The scheduling priority of the request.

        Returns:
            The newly created SequenceGroup.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model's ``max_logprobs``.
        """
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

        sampling_params = self._build_logits_processors(
            sampling_params, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()

        sampling_params.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        # Create the sequence group. draft_size is 1 unless speculative
        # decoding is configured, in which case it is
        # num_speculative_tokens + 1.
        draft_size = 1
        if self.vllm_config.speculative_config is not None:
            draft_size = \
                self.vllm_config.speculative_config.num_speculative_tokens + 1

        seq_group = SequenceGroup(request_id=request_id,
                                  seqs=[seq],
                                  arrival_time=arrival_time,
                                  sampling_params=sampling_params,
                                  lora_request=lora_request,
                                  trace_headers=trace_headers,
                                  encoder_seq=encoder_seq,
                                  priority=priority,
                                  draft_size=draft_size)

        return seq_group

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts one or more requests with the given ID(s).

        Args:
            request_id: The ID(s) of the request to abort.

        Details:
            - Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][].

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        # The ID is not mapped to a particular virtual engine here, so the
        # abort is forwarded to every scheduler.
        for scheduler in self.scheduler:
            scheduler.abort_seq_group(
                request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)

    def get_vllm_config(self) -> VllmConfig:
        """Gets the vllm configuration."""
        return self.vllm_config

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

    def get_parallel_config(self) -> ParallelConfig:
        """Gets the parallel configuration."""
        return self.parallel_config

    def get_decoding_config(self) -> DecodingConfig:
        """Gets the decoding configuration."""
        return self.decoding_config

    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration held by this engine."""
        scheduler_config = self.scheduler_config
        return scheduler_config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Gets the LoRA configuration.

        Returns:
            The LoRA configuration. This may be unset (falsy) when LoRA is
            not enabled; ``add_request`` checks ``not self.lora_config``
            for exactly that case.
        """
        return self.lora_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests, summed over all
        schedulers."""
        return sum(scheduler.get_num_unfinished_seq_groups()
                   for scheduler in self.scheduler)

    def has_unfinished_requests(self) -> bool:
        """Returns True if any scheduler has unfinished requests."""
        return any(scheduler.has_unfinished_seqs()
                   for scheduler in self.scheduler)

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """
        Returns True if there are unfinished requests for the virtual engine.

        Args:
            virtual_engine: Index into ``self.scheduler`` selecting the
                scheduler to query.
        """
        return self.scheduler[virtual_engine].has_unfinished_seqs()

    def reset_mm_cache(self) -> bool:
        """Reset the multi-modal cache.

        Clears the input preprocessor's cache.

        Returns:
            Always True.
        """
        self.input_preprocessor.clear_cache()
        return True

    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset prefix cache for all devices.

        Every scheduler is asked to reset its prefix cache, even when an
        earlier scheduler reports failure.

        Args:
            device: Optional device selector, forwarded unchanged to each
                scheduler's ``reset_prefix_cache``.

        Returns:
            True only if every scheduler reported a successful reset.
        """
        # Materialize all results first. Combining with a short-circuiting
        # `and` (or feeding a generator to `all`) would skip the remaining
        # schedulers after the first failure.
        results = [
            scheduler.reset_prefix_cache(device)
            for scheduler in self.scheduler
        ]
        return all(results)

    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and collect the resulting request outputs.

        Nothing is returned. The created ``RequestOutput`` objects are
        appended to ``ctx.request_outputs``. When
        ``process_request_outputs_callback`` is set, they are passed to it
        and the list is cleared.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, then only this request is going to be
                processed. The head of ``ctx.output_queue`` is left in place,
                and the request's index is appended to that entry's ``skip``
                list so a later full pass does not process it again.
        """
        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        assert not has_multiple_outputs
        outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            if not is_async:
                seq_group.update_num_computed_tokens(
                    seq_group_meta.token_chunk_size or 0)

            # Accumulate model timings onto the group's metrics. Once a
            # metric is set, a missing (None) timing in a later output adds 0.
            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            self.output_processor.process_prompt_logprob(seq_group, output)
            if seq_group_meta.do_sample:
                self.output_processor.process_outputs(seq_group, output,
                                                      is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # Create the outputs for the remaining requests. `skip` is not
        # modified on this path, so a single set gives O(1) membership tests
        # instead of scanning three lists per index.
        already_handled = set(skip).union(finished_before, finished_now)
        for i in indices:
            if i in already_handled:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # Create outputs only after processing the scheduler's results.
        # Ignored groups using the DELTA output kind only produce an output
        # once they are finished.
        for seq_group in scheduler_outputs.ignored_seq_groups:
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _advance_to_next_step(
            self, output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done inside output processor, but it is
        required if the worker is to perform async forward pass to next step.

        Args:
            output: The sampler output of the run; iterated in lockstep with
                the two lists below (one entry per scheduled group).
            seq_group_metadata_list: Metadata of the scheduled groups.
            scheduled_seq_groups: The scheduled groups themselves.
        """
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

            token_chunk_size = (seq_group_metadata.token_chunk_size
                                if seq_group_metadata.token_chunk_size
                                is not None else 0)
            seq_group.update_num_computed_tokens(token_chunk_size)

            if seq_group_metadata.do_sample:
                # The async path only supports one sample from one sequence
                # per group.
                assert len(sequence_group_outputs.samples) == 1, (
                    "Async output processor expects a single sample"
                    " (i.e sampling_params.n == 1)")
                sample = sequence_group_outputs.samples[0]

                assert len(seq_group.seqs) == 1
                seq = seq_group.seqs[0]

                seq.append_token_id(sample.output_token, sample.logprobs,
                                    sample.output_embed)

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        <figure markdown="span">
        ![Overview of the step function](https://i.imgur.com/sv2HssD.png)
        <figcaption>Overview of the step function</figcaption>
        </figure>

        Details:
        - Step 1: Schedules the sequences to be executed in the next
            iteration and the token blocks to be swapped in/out/copy.

            - Depending on the scheduling policy,
                sequences may be `preempted/reordered`.
            - A Sequence Group (SG) refers to a group of sequences
                that are generated from the same prompt.

        - Step 2: Calls the distributed executor to execute the model.
        - Step 3: Processes the model output. This mainly includes:

            - Decodes the relevant outputs.
            - Updates the scheduled sequence groups with model outputs
                based on its `sampling parameters` (`use_beam_search` or not).
            - Frees the finished sequence groups.

        - Finally, it creates and returns the newly generated results.

        Raises:
            NotImplementedError: If pipeline parallelism is configured.
            InputProcessingError: Re-raised after the offending request has
                been aborted; the rest of its batch is retried next step.

        Example:
        ```
        # Please see the example/ folder for more detailed examples.

        # initialize engine and request arguments
        engine = LLMEngine.from_engine_args(engine_args)
        example_inputs = [(0, "What is LLM?",
                           SamplingParams(temperature=0.0))]

        # Start the engine with an event loop
        while True:
            if example_inputs:
                req_id, prompt, sampling_params = example_inputs.pop(0)
                engine.add_request(str(req_id), prompt, sampling_params)

            # continue the request processing
            request_outputs = engine.step()
            for request_output in request_outputs:
                if request_output.finished:
                    # return or show the request output
                    ...

            if not (engine.has_unfinished_requests() or example_inputs):
                break
        ```
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # For llm_engine, there is no pipeline parallel support, so the engine
        # used is always 0.
        virtual_engine = 0

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

        else:
            finished_requests_ids = list()
            # The cached schedule is consumed by this step, so clear the flag
            # now. Previously it was only cleared after a successful
            # execute_model, so an empty cached schedule left it set forever
            # and the scheduler was never called again.
            # _abort_and_cache_schedule sets it again if another retry is
            # needed.
            self._skip_scheduling_next_step = False

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise

        else:
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            outputs = []

        if not self._has_remaining_steps(seq_group_metadata_list):
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            # Add results to the output_queue
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")

                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Check if need to run the usual non-async path
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise time out, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs

    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Aborts a single request, and caches the scheduler outputs minus that
        request. This allows the next step to continue processing the remaining
        requests without having to re-run the scheduler.

        Args:
            request_id: The request to abort.
            virtual_engine: The virtual engine whose cached schedule is set.
            seq_group_metadata_list: Metadata of the current schedule. The
                aborted request's entry is removed in place.
            scheduler_outputs: The current scheduler outputs. The aborted
                request's scheduled group is removed in place.
            allow_async_output_proc: Stored alongside the cached schedule.
        """
        # Abort the request and remove its sequence group from the current
        # schedule
        self.abort_request(request_id)
        for i, metadata in enumerate(seq_group_metadata_list):
            if metadata.request_id == request_id:
                del seq_group_metadata_list[i]
                break
        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
            if group.seq_group.request_id == request_id:
                del scheduler_outputs.scheduled_seq_groups[i]
                break

        # If there are still other sequence groups left in the schedule, cache
        # them and flag the engine to reuse the schedule.
        if len(seq_group_metadata_list) > 0:
            self._skip_scheduling_next_step = True
            # Reuse multi-step caching logic
            self._cache_scheduler_outputs_for_multi_step(
                virtual_engine=virtual_engine,
                scheduler_outputs=scheduler_outputs,
                seq_group_metadata_list=seq_group_metadata_list,
                allow_async_output_proc=allow_async_output_proc)
        else:
            # Nothing is left to retry. The flag may still be True from an
            # earlier failed step whose cached schedule we were re-running.
            # Leaving it set would make `step` keep reusing the now-empty
            # schedule and never call the scheduler again.
            self._skip_scheduling_next_step = False

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return whether the current batch still has steps left to run.

        This implementation always returns False, so ``step`` never takes its
        multi-step branch. ``seq_group_metadata_list`` is currently unused.
        """
        return False

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store a schedule in ``cached_scheduler_outputs[virtual_engine]``.

        ``step`` reads these cached values at its start, so a stored schedule
        can be reused instead of calling the scheduler. Any previously cached
        ``last_output`` is reset to None.
        """
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Remember the final sampler output of a step for pipeline parallelism.

        Does nothing unless ``pipeline_parallel_size > 1``, ``output`` is
        non-empty and its first entry is not None. In that case the last
        entry is stored as ``last_output`` of the cached scheduler outputs
        for ``virtual_engine``, after checking that only the CPU copy of the
        sampled token ids is populated.
        """
        pp_enabled = self.parallel_config.pipeline_parallel_size > 1
        if not pp_enabled or not output or output[0] is None:
            return

        final_output = output[-1]
        assert final_output is not None
        assert final_output.sampled_token_ids_cpu is not None
        assert final_output.sampled_token_ids is None
        assert final_output.sampled_token_probs is None
        cache_entry = self.cached_scheduler_outputs[virtual_engine]
        cache_entry.last_output = final_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        """Return the last sampled token ids to forward to the executor.

        This implementation always returns None; ``step`` passes the result
        to ``ExecuteModelRequest.last_sampled_token_ids``.
        """
        return None

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under ``logger_name``.

        Args:
            logger_name: Unique name for the logger.
            logger: The stat logger to register.

        Raises:
            RuntimeError: If stat logging is disabled.
            KeyError: If a logger with ``logger_name`` is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        self.stat_loggers[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger registered under ``logger_name``.

        Args:
            logger_name: Name the logger was registered with.

        Raises:
            RuntimeError: If stat logging is disabled.
            KeyError: If no logger with ``logger_name`` is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Collect stats and hand them to every registered stat logger.

        Does nothing when stat logging is disabled. The arguments are
        forwarded unchanged to ``_get_stats``. Calling this with no arguments
        forces a log, e.g. when no requests are active.

        Args:
            scheduler_outputs: Optional outputs of the scheduled batch.
            model_output: Optional sampler outputs of the step.
            finished_before: Optional indices of groups that were already
                finished before this step.
            skip: Optional indices of groups to skip.
        """
        if self.log_stats:
            stats = self._get_stats(scheduler_outputs, model_output,
                                    finished_before, skip)
            # Named `stat_logger` so it does not shadow the module `logger`.
            for stat_logger in self.stat_loggers.values():
                stat_logger.log(stats)

    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional. Used to populate the iteration and
                per-request metrics for the scheduled batch. When None, only
                system-level state is reported: queue sizes, KV-cache usage,
                prefix-cache hit rates and LoRA adapters.
            model_output: Optional; not read by this implementation.
            finished_before: Optional. Indices into
                ``scheduler_outputs.scheduled_seq_groups`` of groups that
                already finished (async output processing). They are skipped,
                and each one reduces the batched-token count by one.
            skip: Optional. Indices into
                ``scheduler_outputs.scheduled_seq_groups`` to ignore (currently
                the preempted groups).

        Returns:
            A ``Stats`` snapshot timestamped with ``time.time()``.
        """
        now = time.time()

        # Set lookups: these membership tests run once per scheduled group.
        finished_idx = set(finished_before or ())
        skipped_idx = set(skip or ())

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Exchange the usage and cache hit stats between gpu and cpu when
        # running on cpu because the cpu_worker.py intentionally reports the
        # number of cpu blocks as gpu blocks in favor of cache management.
        if self.device_config.device_type == "cpu":
            gpu_cache_usage_sys, cpu_cache_usage_sys = (
                cpu_cache_usage_sys,
                gpu_cache_usage_sys,
            )
            gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = (
                cpu_prefix_cache_hit_rate,
                gpu_prefix_cache_hit_rate,
            )

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        inter_token_latencies_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests. The Counter keeps first-seen order; only its keys
        # are reported.
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = (str(self.lora_config.max_loras)
                         if self.lora_config else "0")

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_idx:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skipped_idx:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # ITLs
                    latency = seq_group.get_last_token_latency()
                    inter_token_latencies_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            inter_token_latencies_iter=inter_token_latencies_iter,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            max_lora=max_lora_stat,
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter; delegates to ``model_executor.add_lora``."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Unload a LoRA adapter by id; delegates to the model executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the ids of loaded LoRA adapters, as reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter by id; delegates to ``model_executor.pin_lora``."""
        return self.model_executor.pin_lora(lora_id)

    def start_profile(self) -> None:
        """Tell the model executor to begin profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Tell the model executor to stop profiling."""
        self.model_executor.stop_profile()

    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at the given ``level``.

        Raises:
            AssertionError: if sleep mode is not enabled in the model config.
                This is an ``assert``, so it is skipped under ``python -O``.
        """
        assert self.vllm_config.model_config.enable_sleep_mode, (
            "Sleep mode is not enabled in the model config")
        self.model_executor.sleep(level=level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the model executor, passing ``tags`` through unchanged.

        Raises:
            AssertionError: if sleep mode is not enabled in the model config.
                This is an ``assert``, so it is skipped under ``python -O``.
        """
        assert self.vllm_config.model_config.enable_sleep_mode, (
            "Sleep mode is not enabled in the model config")
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute (read, not called)."""
        return self.model_executor.is_sleeping

    def check_health(self) -> None:
        """Run the model executor's health check.

        Errors raised by the executor propagate unchanged.
        """
        self.model_executor.check_health()

    def is_tracing_enabled(self) -> bool:
        """Return True when a tracer has been configured (``self.tracer``)."""
        return self.tracer is not None

    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Emit a trace span for every finished group in the scheduled batch.

        Does nothing when no tracer is configured.

        Args:
            scheduler_outputs: The batch whose ``scheduled_seq_groups`` are
                inspected.
            finished_before: Optional indices into
                ``scheduler_outputs.scheduled_seq_groups`` that were already
                traced (async output processing) and must not be traced twice.
        """
        if self.tracer is None:
            return

        already_traced = set(finished_before or ())
        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if idx in already_traced:
                continue
            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record one ``llm_request`` server span for a finished request.

        The span starts at the request's arrival time (converted to integer
        nanoseconds) under the parent context extracted from the request's
        trace headers. It carries request/response attributes and whichever
        latency measurements are available. Does nothing when tracing is
        disabled or the group has no sampling params.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            sampling_params = seq_group.sampling_params

            # Handle potential None values for cancelled/aborted requests
            ttft = (metrics.first_token_time - metrics.arrival_time
                    if metrics.first_token_time is not None else None)
            e2e_time = (metrics.finished_time - metrics.arrival_time
                        if metrics.finished_time is not None else None)

            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   sampling_params.n)
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Only set timing attributes if the values are available
            if metrics.time_in_queue is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                    metrics.time_in_queue)
            if ttft is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            if e2e_time is not None:
                seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E,
                                       e2e_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # Divided by 1000 here, unlike model_execute_time below.
                # Presumably forward time is recorded in ms; TODO confirm.
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)

    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Validate the encoder (if present) and decoder parts of ``inputs``.

        Raises:
            ValueError: propagated from ``_validate_model_input``.
        """
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        # Decoder-only inputs have no encoder part to validate.
        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate a single encoder or decoder prompt.

        Raises:
            ValueError: if the prompt is empty, except for a multimodal
                encoder prompt or an ``"embeds"`` prompt. Also raised if a
                token id exceeds the tokenizer's ``max_token_id``, or if the
                prompt is longer than ``max_model_len``. An encoder prompt
                whose processor sets ``pad_dummy_encoder_prompt`` is exempt
                from the length check.
        """
        model_config = self.model_config
        tokenizer = (None if self.tokenizer is None else
                     self.tokenizer.get_lora_tokenizer(lora_request))

        prompt_ids = prompt_inputs.get("prompt_token_ids", [])
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            elif prompt_inputs["type"] == "embeds":
                pass
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        if tokenizer is not None:
            max_input_id = max(prompt_ids, default=0)
            if max_input_id > tokenizer.max_token_id:
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer or object(),  # Dummy if no tokenizer
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper and Donut

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the logits_bias, and
        allowed_token_ids fields in sampling_params. Deletes those fields and
        adds the constructed logits processors to the logits_processors field.
        Returns the modified sampling params.

        Processors for ``bad_words`` are added as well. The fields are
        modified in place: ``sampling_params`` is both mutated and returned,
        and an existing ``logits_processors`` list is extended in place.
        """
        logits_processors = []

        if sampling_params.logit_bias or sampling_params.allowed_token_ids:
            tokenizer = self.get_tokenizer(lora_request=lora_request)

            processors = get_openai_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        # A truthiness check rather than len(...) > 0, so that
        # bad_words=None is treated as "no bad words" instead of raising
        # TypeError.
        if sampling_params.bad_words:
            tokenizer = self.get_tokenizer(lora_request=lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

        if logits_processors:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = logits_processors
            else:
                sampling_params.logits_processors.extend(logits_processors)

        return sampling_params

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC to ``model_executor.collective_rpc``.

        Args:
            method: Method name or callable, passed through unchanged.
            timeout: Optional timeout, passed through unchanged.
            args: Positional arguments for ``method``.
            kwargs: Keyword arguments for ``method``; None is passed through
                as-is.

        Returns:
            The list of results returned by the executor.
        """
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

# When VLLM_USE_V1 is explicitly set (envs.is_set) and truthy, rebind this
# module's public ``LLMEngine`` name to the V1 engine class, so code that
# imports ``LLMEngine`` from here after module load receives the V1
# implementation. The import is deferred to this branch, so the V1 module is
# only loaded when selected.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    LLMEngine = V1LLMEngine  # type: ignore