llm_engine.py 79.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Antoni Baum's avatar
Antoni Baum committed
4
import time
5
from collections import Counter as collectionsCounter
6
from collections import deque
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
11
                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
12
from typing import Sequence as GenericSequence
13
from typing import Set, Type, Union, cast
14

15
import torch
16
from typing_extensions import TypeVar
17

18
import vllm.envs as envs
19
20
21
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
22
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
28
29
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
30
from vllm.executor.executor_base import ExecutorBase
31
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
32
from vllm.inputs.parse import split_enc_dec_inputs
33
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.logger import init_logger
35
from vllm.logits_process import get_bad_words_logits_processors
36
from vllm.lora.request import LoRARequest
37
from vllm.model_executor.layers.sampler import SamplerOutput
38
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
39
from vllm.multimodal.processing import EncDecMultiModalProcessor
40
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
41
42
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
43
from vllm.sampling_params import RequestOutputKind, SamplingParams
44
45
46
47
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceStatus)
48
49
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
50
from vllm.transformers_utils.detokenizer import Detokenizer
51
from vllm.transformers_utils.tokenizer import AnyTokenizer
52
from vllm.transformers_utils.tokenizer_group import (
53
    TokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
54
55
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
56
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
57
from vllm.version import __version__ as VLLM_VERSION
58
from vllm.worker.model_runner_base import InputProcessingError
59
60

logger = init_logger(__name__)
61
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
62

63
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
64
_R = TypeVar("_R", default=Any)
65
66


67
68
69
70
71
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step.

    Every field defaults to an "empty" value (None / False), so a fresh
    instance represents a virtual engine with nothing cached yet.
    """
    # Cached per-sequence-group metadata, or None when nothing is cached.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Cached scheduler outputs, or None when nothing is cached.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Cached flag; presumably whether async output processing was allowed
    # for the cached schedule - TODO confirm against the code that sets it.
    allow_async_output_proc: bool = False
    # Cached sampler output, or None when nothing is cached.
    last_output: Optional[SamplerOutput] = None
74
75


76
77
78
79
80
81
class OutputData(NamedTuple):
    """One queued batch of model outputs plus the scheduling state it
    belongs to.

    Instances are created by `SchedulerContext.append_output`, which always
    starts `skip` as an empty list.
    """
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # Indicates if this output is from the first step of the
    # multi-step. When multi-step is disabled, this is always
    # set to True.
    # is_first_step_output is invalid when `outputs` has
    # outputs from multiple steps.
    is_first_step_output: Optional[bool]
    # NOTE(review): semantics of the indices stored here are not visible in
    # this part of the file - confirm against the consumer of the queue.
    skip: List[int]


91
class SchedulerContext:
    """Mutable per-virtual-engine container for pending model outputs.

    Holds a FIFO queue of `OutputData` entries, the list of request outputs
    produced so far, and the most recent scheduling state.
    """

    def __init__(self) -> None:
        # Pending outputs, appended by `append_output`.
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []
        # Scheduling state; None until something assigns it.
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Wrap the arguments in an `OutputData` and enqueue it.

        The `skip` field of the new entry always starts as a fresh empty
        list.
        """
        self.output_queue.append(
            OutputData(outputs=outputs,
                       seq_group_metadata_list=seq_group_metadata_list,
                       scheduler_outputs=scheduler_outputs,
                       is_async=is_async,
                       is_last_step=is_last_step,
                       is_first_step_output=is_first_step_output,
                       skip=[]))
114
115


116
class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The [`LLM`][vllm.LLM] class wraps this class for offline batched inference
    and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine]
    class wraps this class for online serving.

    The config arguments are derived from [`EngineArgs`][vllm.EngineArgs].

    Args:
        vllm_config: The configuration for initializing and running vLLM.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
    """
139

140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

165
        return cast(_O, output)
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        do_validate = cls.DO_VALIDATE_OUTPUT

        outputs_: List[_O]
        if TYPE_CHECKING or do_validate:
            outputs_ = []
            for output in outputs:
                if not isinstance(output, output_type):
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(output)}")

                outputs_.append(output)
        else:
            outputs_ = outputs

        return outputs_

189
    tokenizer: Optional[TokenizerGroup]
190

191
192
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the V0 engine: tokenizer, executor, KV cache, schedulers,
        stat loggers, tracer and output processor.

        Args:
            vllm_config: The configuration for initializing and running vLLM.
            executor_class: The model executor class; instantiated with
                `vllm_config`.
            log_stats: Whether to set up stat loggers.
            usage_context: Entry point, reported with the usage stats.
            stat_loggers: Stat loggers to use instead of the default
                logging + prometheus pair (only consulted when `log_stats`).
            mm_registry: Multi-modal registry handed to the input
                preprocessor.
            use_cached_outputs: Stored on the engine and logged.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is set.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Unpack the sub-configs for convenient access.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None
        else:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.model_executor = executor_class(vllm_config=vllm_config)

        # Pooling runners skip KV cache initialization.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        # One cached-output slot and one scheduler context per
        # pipeline-parallel stage (virtual engine).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            # A string is treated as a qualified name and resolved to a class.
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is only enabled when an OTLP endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(self.scheduler_config.max_model_len,
                                         get_tokenizer_for_seq),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

394
395
396
397
398
399
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache. If `cache_config.num_gpu_blocks_override` is
        set, it replaces the profiled GPU block count. The final counts are
        written back to `cache_config` and passed to the executor.
        """
        start = time.time()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.time() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
419

420
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Map `parallel_config.distributed_executor_backend` to an executor
        class.

        The backend may be an `ExecutorBase` subclass (returned as-is) or one
        of the strings "ray", "mp", "uni" or "external_launcher". Executor
        modules are imported lazily, only for the selected backend.

        Raises:
            TypeError: If the backend is a class that does not subclass
                `ExecutorBase`.
            ValueError: If the backend is not recognized.
        """
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingDistributedExecutor
        elif distributed_executor_backend == "uni":
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # executor with external launcher
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("unrecognized distributed_executor_backend: "
                             f"{distributed_executor_backend}")
        return executor_class

458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Alternate constructor: build an engine straight from a
        `VllmConfig`.

        The executor class is picked by `_get_executor_cls`, and stat logging
        is enabled unless `disable_log_stats` is True.
        """
        executor_cls = cls._get_executor_cls(vllm_config)
        want_stats = not disable_log_stats
        engine = cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=want_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

474
475
476
477
478
479
480
481
482
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        When `envs.VLLM_USE_V1` is set, construction is delegated to the V1
        `LLMEngine` class instead of `cls`.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)

        engine_cls = cls
        if envs.VLLM_USE_V1:
            # Imported lazily, only when the V1 engine is requested.
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            engine_cls = V1LLMEngine

        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
        )
496

497
498
499
500
501
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

502
503
504
505
506
507
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

508
509
    def get_tokenizer_group(self) -> TokenizerGroup:
        if self.tokenizer is None:
510
511
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
512

513
        return self.tokenizer
514

515
    def get_tokenizer(
516
517
518
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
519
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
520

521
    def _init_tokenizer(self) -> TokenizerGroup:
        """Create the tokenizer group from the model, scheduler and LoRA
        configs held by the engine."""
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            lora_config=self.lora_config)
526

527
528
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
529
        self.cache_config.verify_with_parallel_config(self.parallel_config)
530
531
532
533
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
534

535
536
537
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Returns:
            The created sequence group, or None when the request asks for
            `n > 1` samples and is handed off to
            `ParallelSampleSequenceGroup.add_request` instead.

        Raises:
            ValueError: If `params` is neither `SamplingParams` nor
                `PoolingParams`.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request)

        # The encoder sequence (if any) shares the decoder sequence's id.
        encoder_seq = (None if encoder_inputs is None else Sequence(
            seq_id, encoder_inputs, block_size, eos_token_id, lora_request))

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        # `min` returns the first minimal element, so ties resolve to the
        # lowest-index scheduler.
        min_cost_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group

609
610
    def stop_remote_worker_execution_loop(self) -> None:
        """Forward the call to the model executor's method of the same
        name."""
        self.model_executor.stop_remote_worker_execution_loop()
611

612
613
614
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See
                [PromptType][vllm.inputs.PromptType]
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                [SamplingParams][vllm.SamplingParams] for text generation.
                [PoolingParams][vllm.PoolingParams] for pooling.
            arrival_time: The arrival time of the request. If None, we use
                the current time as returned by `time.time()`.
            lora_request: The LoRA request to add.
            tokenization_kwargs: Forwarded to the input preprocessor's
                `preprocess` call.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Raises:
            TypeError: If `request_id` is not a string.
            ValueError: If a LoRA request is given but LoRA is not enabled,
                if a non-zero priority is given without priority scheduling,
                or if `params` carries logits processors.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create `n` number of [Sequence][vllm.Sequence] objects.
            - Create a [SequenceGroup][vllm.SequenceGroup] object
              from the list of [Sequence][vllm.Sequence].
            - Add the [SequenceGroup][vllm.SequenceGroup] object to the
              scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if not isinstance(request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request_id)}")

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and not self.scheduler_config.policy == "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # NOTE(review): this check is unconditional even though the message
        # mentions multi-step decoding - confirm that rejecting every request
        # with logits processors is intended.
        if isinstance(params, SamplingParams) \
            and params.logits_processors:
            raise ValueError(
                "Logits processors are not supported in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        # For embeds-only prompts, fill in placeholder token ids (all zeros),
        # one per row of `prompt_embeds`. This mutates the caller's dict.
        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt["prompt_token_ids"] = [0] * seq_len

        processed_inputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )
710
711
712
713
714
715

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a sampling request.

        Args:
            request_id: ID of the request the group belongs to.
            seq: The (decoder) sequence wrapped by the group.
            sampling_params: Sampling parameters supplied by the caller. The
                caller's object is not handed to the group; a clone is.
            arrival_time: Arrival time recorded on the group.
            lora_request: Optional LoRA request forwarded to the group.
            trace_headers: Optional trace headers forwarded to the group.
            encoder_seq: Optional encoder sequence forwarded to the group.
            priority: Priority forwarded to the group.

        Returns:
            The newly created SequenceGroup.

        Raises:
            ValueError: If `logprobs` or `prompt_logprobs` is larger than the
                model config's `max_logprobs`.
        """
        logprobs_cap = self.get_model_config().max_logprobs
        for requested in (sampling_params.logprobs,
                          sampling_params.prompt_logprobs):
            # Unset / zero values are never over the limit.
            if requested and requested > logprobs_cap:
                raise ValueError(f"Cannot request more than "
                                 f"{logprobs_cap} logprobs.")

        # Build the logits processors first, then take a defensive copy of
        # the result for the sampler. The copy does not deep-copy
        # LogitsProcessor objects.
        params = self._build_logits_processors(sampling_params,
                                               lora_request).clone()
        params.update_from_generation_config(self.generation_config_fields,
                                             seq.eos_token_id)

        # One slot for the regular token, plus the speculative tokens when a
        # speculative config is present.
        spec_config = self.vllm_config.speculative_config
        if spec_config is None:
            draft_size = 1
        else:
            draft_size = spec_config.num_speculative_tokens + 1

        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             sampling_params=params,
                             lora_request=lora_request,
                             trace_headers=trace_headers,
                             encoder_seq=encoder_seq,
                             priority=priority,
                             draft_size=draft_size)

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a pooling request.

        Args:
            request_id: ID of the request the group belongs to.
            seq: The sequence wrapped by the group.
            pooling_params: Pooling parameters supplied by the caller; the
                group receives a clone, not the caller's object.
            arrival_time: Arrival time recorded on the group.
            lora_request: Optional LoRA request forwarded to the group.
            encoder_seq: Optional encoder sequence forwarded to the group.
            priority: Priority forwarded to the group.

        Returns:
            The newly created SequenceGroup.
        """
        # Defensive copy: the pooler works on its own PoolingParams instance.
        params_copy = pooling_params.clone()
        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             lora_request=lora_request,
                             pooling_params=params_copy,
                             encoder_seq=encoder_seq,
                             priority=priority)
780

Antoni Baum's avatar
Antoni Baum committed
781
782
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.
783
784

        Args:
Antoni Baum's avatar
Antoni Baum committed
785
            request_id: The ID(s) of the request to abort.
786
787

        Details:
788
            - Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][].
789
790
791
792
793
794

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
795
        """
796
        for scheduler in self.scheduler:
797
798
            scheduler.abort_seq_group(
                request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
799

800
801
802
803
    def get_vllm_config(self) -> VllmConfig:
        """Gets the vllm configuration."""
        return self.vllm_config

804
805
806
807
    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

808
809
810
811
    def get_parallel_config(self) -> ParallelConfig:
        """Gets the parallel configuration."""
        return self.parallel_config

812
813
814
815
    def get_decoding_config(self) -> DecodingConfig:
        """Gets the decoding configuration."""
        return self.decoding_config

816
817
818
819
820
821
822
823
    def get_scheduler_config(self) -> SchedulerConfig:
        """Gets the scheduler configuration."""
        return self.scheduler_config

    def get_lora_config(self) -> LoRAConfig:
        """Gets the LoRA configuration."""
        return self.lora_config

824
    def get_num_unfinished_requests(self) -> int:
825
        """Gets the number of unfinished requests."""
826
827
        return sum(scheduler.get_num_unfinished_seq_groups()
                   for scheduler in self.scheduler)
828

829
    def has_unfinished_requests(self) -> bool:
830
        """Returns True if there are unfinished requests."""
831
832
833
834
835
836
837
838
839
        return any(scheduler.has_unfinished_seqs()
                   for scheduler in self.scheduler)

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """
        Returns True if there are unfinished requests for the virtual engine.
        """
        return self.scheduler[virtual_engine].has_unfinished_seqs()
840

841
842
    def reset_mm_cache(self) -> bool:
        """Reset the multi-modal cache."""
843
844
        return self.input_preprocessor.mm_registry.reset_processor_cache(
            self.model_config)
845

846
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
847
848
849
850
        """Reset prefix cache for all devices."""

        success = True
        for scheduler in self.scheduler:
851
            success = success and scheduler.reset_prefix_cache(device)
852
853
        return success

854
    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Apply a pooling result to a sequence group and finish it.

        Stores the `data` of the first output on `seq_group.pooled_data` and
        marks every sequence of the group as `FINISHED_STOPPED`.

        Args:
            seq_group: The sequence group to update in place.
            outputs: Pooling outputs for the group; only `outputs[0]` is read.
        """
        seq_group.pooled_data = outputs[0].data

        for sequence in seq_group.get_seqs():
            sequence.status = SequenceStatus.FINISHED_STOPPED

866
867
868
869
870
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and collect the resulting request outputs.

        Nothing is returned: request outputs are appended to
        `ctx.request_outputs` and, if `process_request_outputs_callback` is
        set, handed to that callback and then cleared from the context.

        Args:
            ctx: The virtual engine context to work on. Its `output_queue`
                holds the pending entries; each entry carries the model
                outputs, the scheduled metadata and a `skip` list of indices.
            request_id: If provided, then only this request is going to be
                processed. In that case the head queue entry is inspected
                without being removed, and the request's index is added to
                that entry's `skip` list so it is not processed again later.
        """

        # Single timestamp used for all first/last-token time updates below.
        now = time.time()

        # Nothing pending => nothing to do.
        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        # Only a single output set is supported here (asserted below), so the
        # per-group output is always taken from `outputs[0]` further down.
        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        assert not has_multiple_outputs
        outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        # Indices of groups that were already finished before this call ...
        finished_before: List[int] = []
        # ... and of groups that became finished while processing this call.
        finished_now: List[int] = []
        for i in indices:
            # Already handled by an earlier single-request call.
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            # In the async case the computed-token count is advanced in
            # `_advance_to_next_step` instead, so it is only updated here for
            # the non-async path.
            if not is_async:
                seq_group.update_num_computed_tokens(
                    seq_group_meta.token_chunk_size or 0)

            # Accumulate model forward/execute times into the group's metrics
            # (initialising them from the output when they are still None).
            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            # Pooling runners store the pooled data and finish the group;
            # otherwise the output processor handles prompt logprobs and,
            # when sampling was requested, the sampled outputs.
            if self.model_config.runner_type == "pooling":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # Create the outputs
        for i in indices:
            if i in skip or i in finished_before or i in finished_now:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # Create outputs only after processing the scheduler's results

        for seq_group in scheduler_outputs.ignored_seq_groups:
            # Unfinished groups that stream deltas produce no output here.
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _advance_to_next_step(
1057
            self, output: SamplerOutput,
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done inside output processor, but it is
        required if the worker is to perform async forward pass to next step.
        """
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

1071
1072
1073
1074
            token_chunk_size = (seq_group_metadata.token_chunk_size
                                if seq_group_metadata.token_chunk_size
                                is not None else 0)
            seq_group.update_num_computed_tokens(token_chunk_size)
1075

1076
1077
1078
            if seq_group_metadata.do_sample:
                assert len(sequence_group_outputs.samples) == 1, (
                    "Async output processor expects a single sample"
1079
                    " (i.e sampling_params.n == 1)")
1080
1081
1082
1083
                sample = sequence_group_outputs.samples[0]

                assert len(seq_group.seqs) == 1
                seq = seq_group.seqs[0]
1084

1085
1086
                seq.append_token_id(sample.output_token, sample.logprobs,
                                    sample.output_embed)
1087

1088
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        <figure markdown="span">
        ![Overview of the step function](https://i.imgur.com/sv2HssD.png)
        <figcaption>Overview of the step function</figcaption>
        </figure>

        Details:
        - Step 1: Schedules the sequences to be executed in the next
            iteration and the token blocks to be swapped in/out/copy.

            - Depending on the scheduling policy,
                sequences may be `preempted/reordered`.
            - A Sequence Group (SG) refers to a group of sequences
                that are generated from the same prompt.

        - Step 2: Calls the distributed executor to execute the model.
        - Step 3: Processes the model output. This mainly includes:

            - Decodes the relevant outputs.
            - Updates the scheduled sequence groups with model outputs
                based on its `sampling parameters` (`use_beam_search` or not).
            - Frees the finished sequence groups.

        - Finally, it creates and returns the newly generated results.

        Returns:
            The request outputs collected in the scheduler context
            (`ctx.request_outputs`) during this step.

        Raises:
            NotImplementedError: If `pipeline_parallel_size` is greater
                than 1.
            InputProcessingError: Re-raised from the model executor after the
                offending request has been aborted and the remaining schedule
                has been cached for the next step.

        Example:
        ```
        # Please see the example/ folder for more detailed examples.

        # initialize engine and request arguments
        engine = LLMEngine.from_engine_args(engine_args)
        example_inputs = [(0, "What is LLM?",
        SamplingParams(temperature=0.0))]

        # Start the engine with an event loop
        while True:
            if example_inputs:
                req_id, prompt, sampling_params = example_inputs.pop(0)
                engine.add_request(str(req_id),prompt,sampling_params)

            # continue the request processing
            request_outputs = engine.step()
            for request_output in request_outputs:
                if request_output.finished:
                    # return or show the request output
                    pass

            if not (engine.has_unfinished_requests() or example_inputs):
                break
        ```
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # For llm_engine, there is no pipeline parallel support, so the engine
        # used is always 0.
        virtual_engine = 0

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            # Expose the fresh schedule on the context as well.
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

        else:
            # Re-running a cached schedule: no newly finished request IDs are
            # collected from the scheduler in this step.
            finished_requests_ids = list()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                # Successful execution: the next step schedules normally.
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise

        else:
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            outputs = []

        if not self._has_remaining_steps(seq_group_metadata_list):
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            # Add results to the output_queue
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            # Async path: the queued output is post-processed later, so the
            # sampled tokens are appended to the sequences right away.
            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")

                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Check if need to run the usual non-async path
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1296

1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Aborts a single request, and caches the scheduler outputs minus that
        request. This allows the next step to continue processing the remaining
        requests without having to re-run the scheduler."""

        # Abort the request and remove its sequence group from the current
        # schedule
        self.abort_request(request_id)
        for i, metadata in enumerate(seq_group_metadata_list):
            if metadata.request_id == request_id:
                del seq_group_metadata_list[i]
                break
        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
            if group.seq_group.request_id == request_id:
                del scheduler_outputs.scheduled_seq_groups[i]
                break

        # If there are still other sequence groups left in the schedule, cache
        # them and flag the engine to reuse the schedule.
        if len(seq_group_metadata_list) > 0:
            self._skip_scheduling_next_step = True
            # Reuse multi-step caching logic
            self._cache_scheduler_outputs_for_multi_step(
                virtual_engine=virtual_engine,
                scheduler_outputs=scheduler_outputs,
                seq_group_metadata_list=seq_group_metadata_list,
                allow_async_output_proc=allow_async_output_proc)

1329
1330
1331
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
1332
        return False
1333
1334
1335
1336

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
1337
1338
1339
1340
1341
1342
1343
1344
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                and output[0] is not None):
            last_output = output[-1]
            assert last_output is not None
            assert last_output.sampled_token_ids_cpu is not None
            assert last_output.sampled_token_ids is None
            assert last_output.sampled_token_probs is None
            self.cached_scheduler_outputs[
                virtual_engine].last_output = last_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        return None

1363
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under ``logger_name``.

        Args:
            logger_name: Key under which the logger is stored in
                ``self.stat_loggers``.
            logger: The logger instance to register.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If a logger with the same name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        self.stat_loggers[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Args:
            logger_name: Key of the logger in ``self.stat_loggers``.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If no logger with that name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

1381
1382
1383
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Forced log when no requests active.

        When ``self.log_stats`` is truthy, a stats snapshot is built with
        ``self._get_stats`` and handed to every registered stat logger.
        Otherwise the call is a no-op.

        Args:
            scheduler_outputs: Forwarded to ``_get_stats``.
            model_output: Forwarded to ``_get_stats``.
            finished_before: Forwarded to ``_get_stats``.
            skip: Forwarded to ``_get_stats``.
        """
        if self.log_stats:
            stats = self._get_stats(scheduler_outputs, model_output,
                                    finished_before, skip)
            for logger in self.stat_loggers.values():
                logger.log(stats)
1392

1393
1394
1395
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional. Accepted for interface compatibility; it
                is not read by this method.
            finished_before: Optional, indices of sequences that were finished
                before. These sequences will be ignored.
            skip: Optional, indices of sequences that were preempted. These
                sequences will be ignored.

        Returns:
            A ``Stats`` snapshot combining system state, per-iteration
            counters and per-request data for requests that finished in this
            iteration.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Exchange the usage and cache hit stats between gpu and cpu when
        # running on cpu because the cpu_worker.py intentionally reports the
        # number of cpu blocks as gpu blocks in favor of cache management.
        if self.device_config.device_type == "cpu":
            num_total_gpu, num_total_cpu = num_total_cpu, num_total_gpu
            gpu_cache_usage_sys, cpu_cache_usage_sys = (
                cpu_cache_usage_sys,
                gpu_cache_usage_sys,
            )
            gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = (
                cpu_prefix_cache_hit_rate,
                gpu_prefix_cache_hit_rate,
            )

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if finished_before and idx in finished_before:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if skip and idx in skip:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_token_latency()
                    time_per_output_tokens_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            max_lora=str(max_lora_stat),
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
1653

1654
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the model executor.

        Returns:
            Whatever ``self.model_executor.add_lora`` returns.
        """
        return self.model_executor.add_lora(lora_request)
1656
1657

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal for ``lora_id`` to the model executor.

        Returns:
            Whatever ``self.model_executor.remove_lora`` returns.
        """
        return self.model_executor.remove_lora(lora_id)
1659

1660
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model executor."""
        return self.model_executor.list_loras()
1662

1663
1664
1665
    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model executor to pin ``lora_id`` and relay its answer."""
        executor = self.model_executor
        pinned = executor.pin_lora(lora_id)
        return pinned

1666
1667
1668
1669
1670
1671
    def start_profile(self) -> None:
        """Tell the model executor to start profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Tell the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

1672
1673
1674
1675
1676
    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at the given ``level``.

        Asserts that sleep mode is enabled in the model config before
        delegating to ``self.model_executor.sleep``.
        """
        sleep_mode_enabled = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_mode_enabled, (
            "Sleep mode is not enabled in the model config")
        executor = self.model_executor
        executor.sleep(level=level)

1677
    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the model executor up, optionally restricted to ``tags``.

        Asserts that sleep mode is enabled in the model config, then passes
        ``tags`` through unchanged to ``self.model_executor.wake_up``.

        Args:
            tags: Optional list of tags handed to the executor; ``None`` by
                default.
        """
        assert self.vllm_config.model_config.enable_sleep_mode, (
            "Sleep mode is not enabled in the model config")
        self.model_executor.wake_up(tags)
1681

1682
1683
1684
    def is_sleeping(self) -> bool:
        """Return the model executor's ``is_sleeping`` attribute."""
        executor = self.model_executor
        return executor.is_sleeping

1685
    def check_health(self) -> None:
        """Run the model executor's health check.

        Any exception raised by ``self.model_executor.check_health`` is
        propagated to the caller.
        """
        self.model_executor.check_health()
1687
1688
1689
1690

    def is_tracing_enabled(self) -> bool:
        """Return ``True`` when a tracer is configured on this engine."""
        tracer_missing = self.tracer is None
        return not tracer_missing

1691
1692
1693
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Emit a trace span for every scheduled sequence group that finished.

        Does nothing when no tracer is configured.

        Args:
            scheduler_outputs: Scheduler outputs whose
                ``scheduled_seq_groups`` are inspected.
            finished_before: Optional indices into ``scheduled_seq_groups``
                that are skipped, so they are not traced twice.
        """
        if self.tracer is None:
            return

        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if finished_before and idx in finished_before:
                continue

            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record one "llm_request" span describing ``seq_group``.

        Returns immediately when no tracer is configured or the group has no
        sampling params. Otherwise a server span is opened with the request's
        arrival time as its start time, and request parameters, token usage
        and whichever latency values are available are attached as
        attributes.

        Args:
            seq_group: The sequence group to describe.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics

            # Handle potential None values for cancelled/aborted requests
            ttft = (metrics.first_token_time - metrics.arrival_time
                    if metrics.first_token_time is not None else None)

            e2e_time = (metrics.finished_time - metrics.arrival_time
                        if metrics.finished_time is not None else None)

            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   seq_group.sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   seq_group.sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   seq_group.sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   seq_group.sampling_params.n)
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            # Completion tokens are summed over finished sequences only.
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Only set timing attributes if the values are available
            if metrics.time_in_queue is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                    metrics.time_in_queue)
            if ttft is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            if e2e_time is not None:
                seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E,
                                       e2e_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # NOTE(review): the division by 1000 presumably converts
                # milliseconds to seconds -- confirm against the producer.
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)
1774

1775
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Validate the encoder (if any) and decoder parts of ``inputs``.

        ``inputs`` is split with ``split_enc_dec_inputs``; the encoder part is
        validated only when present, the decoder part always.

        Args:
            inputs: Processed inputs of a request.
            lora_request: Optional LoRA request, passed on to the per-prompt
                validation.

        Raises:
            ValueError: Propagated from ``_validate_model_input`` when a
                prompt is rejected.
        """
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")
1787

1788
1789
1790
1791
1792
1793
1794
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
1795
1796
1797
        model_config = self.model_config
        tokenizer = (None if self.tokenizer is None else
                     self.tokenizer.get_lora_tokenizer(lora_request))
1798

1799
        prompt_ids = prompt_inputs.get("prompt_token_ids", [])
1800
1801
1802
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
1803
            elif prompt_inputs["type"] == "embeds":
1804
                pass
1805
1806
1807
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

1808
1809
1810
1811
1812
1813
        if tokenizer is not None:
            max_input_id = max(prompt_ids, default=0)
            if max_input_id > tokenizer.max_token_id:
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")

1814
        max_prompt_len = self.model_config.max_model_len
1815
        if len(prompt_ids) > max_prompt_len:
1816
            if prompt_type == "encoder" and model_config.is_multimodal_model:
1817
1818
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
1819
1820
1821
                    model_config,
                    tokenizer=tokenizer or object(),  # Dummy if no tokenizer
                )
1822
                assert isinstance(mm_processor, EncDecMultiModalProcessor)
1823

1824
                if mm_processor.pad_dummy_encoder_prompt:
汪志鹏's avatar
汪志鹏 committed
1825
                    return  # Skip encoder length check for Whisper and Donut
1826

1827
            if model_config.is_multimodal_model:
1828
                suggestion = (
1829
1830
1831
1832
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
1833
1834
1835
1836
1837
1838
1839
1840
1841
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")
1842
1843
1844
1845

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens
1846
1847
1848
1849

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
1850
1851
1852
1853
        """Constructs logits processors based on the logits_bias, and
        allowed_token_ids fields in sampling_params. Deletes those fields and
        adds the constructed logits processors to the logits_processors field.
        Returns the modified sampling params."""
1854
1855

        logits_processors = []
1856

1857
1858
1859
        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
            tokenizer = self.get_tokenizer(lora_request=lora_request)

1860
            processors = get_openai_logits_processors(
1861
1862
1863
1864
1865
1866
1867
1868
1869
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

1870
1871
1872
1873
1874
1875
        if len(sampling_params.bad_words) > 0:
            tokenizer = self.get_tokenizer(lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

1876
1877
1878
1879
1880
1881
1882
        if logits_processors:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = logits_processors
            else:
                sampling_params.logits_processors.extend(logits_processors)

        return sampling_params
1883

1884
1885
1886
1887
1888
1889
1890
1891
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Relay a collective RPC to the model executor.

        All four arguments are passed positionally, unchanged, to
        ``self.model_executor.collective_rpc`` and its result is returned.
        """
        executor = self.model_executor
        results = executor.collective_rpc(method, timeout, args, kwargs)
        return results

1892

1893
1894
1895
# When VLLM_USE_V1 is reported as set and its value is truthy, rebind the
# module-level name `LLMEngine` to the V1 engine class, so that code importing
# `LLMEngine` from this module receives the V1 implementation instead.
# NOTE(review): `envs.is_set` presumably distinguishes an explicitly provided
# setting from a default value -- confirm against vllm.envs.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    # Imported lazily so the V1 module is only loaded when it is selected.
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    LLMEngine = V1LLMEngine  # type: ignore