llm_engine.py 93 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import copy
Antoni Baum's avatar
Antoni Baum committed
4
import time
5
from collections import Counter as collectionsCounter
6
from collections import deque
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
11
                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
12
from typing import Sequence as GenericSequence
13
from typing import Set, Type, Union, cast, overload
14

15
import torch
16
from typing_extensions import TypeVar, deprecated
17

18
import vllm.envs as envs
19
20
21
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
22
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
28
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
29
30
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
31
from vllm.executor.executor_base import ExecutorBase
32
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
33
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
34
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.logger import init_logger
36
from vllm.logits_process import get_bad_words_logits_processors
37
from vllm.lora.request import LoRARequest
38
39
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
40
from vllm.model_executor.layers.sampler import SamplerOutput
41
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
42
from vllm.multimodal.processing import EncDecMultiModalProcessor
43
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
44
45
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
46
from vllm.prompt_adapter.request import PromptAdapterRequest
47
from vllm.sampling_params import RequestOutputKind, SamplingParams
48
49
50
51
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceStatus)
52
53
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
54
from vllm.transformers_utils.detokenizer import Detokenizer
55
from vllm.transformers_utils.tokenizer import AnyTokenizer
56
from vllm.transformers_utils.tokenizer_group import (
57
    TokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
58
59
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
60
61
from vllm.utils import (Counter, Device, deprecate_kwargs,
                        resolve_obj_by_qualname, weak_bind)
62
from vllm.version import __version__ as VLLM_VERSION
63
from vllm.worker.model_runner_base import InputProcessingError
64
65

# Type variable for validated request outputs: constrained to the two
# concrete request-output classes.
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
# Generic type variable that falls back to ``Any`` when not parameterized.
_R = TypeVar("_R", default=Any)

logger = init_logger(__name__)

# Reporting interval (in seconds) handed to the stat loggers as their
# ``local_interval``.
_LOCAL_LOGGING_INTERVAL_SEC = 5
70
71


72
73
74
75
76
@dataclass
class SchedulerOutputState:
    """Scheduler outputs remembered for one virtual engine (Multi-Step).

    Every field has an "empty" default, so an instance can be created with
    no arguments and populated later.
    """

    # Metadata of the scheduled sequence groups; None while nothing is cached.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scheduler result for the cached step; None while nothing is cached.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Async-output-processing flag recorded together with the cached step.
    allow_async_output_proc: bool = False
    # Most recent sampler output, if one has been stored.
    last_output: Optional[SamplerOutput] = None
79
80


81
82
83
84
85
86
class OutputData(NamedTuple):
    """One batch of model outputs waiting in ``SchedulerContext.output_queue``.

    Field order is part of the positional constructor and must not change.
    """

    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # True when this output belongs to the first step of a multi-step run;
    # always True when multi-step is disabled. The value carries no meaning
    # when ``outputs`` mixes results from several steps.
    is_first_step_output: Optional[bool]
    # SchedulerContext.append_output always starts this as an empty list.
    # NOTE(review): presumably indices of entries to skip during output
    # processing -- confirm against the consumer.
    skip: List[int]


96
class SchedulerContext:
    """Mutable state shared by scheduling and output processing.

    The engine keeps one instance per virtual engine.
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        # Model outputs waiting to be processed, oldest first.
        self.output_queue: Deque[OutputData] = deque()
        # Request outputs accumulated for this context.
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []
        # Most recent scheduling result; None until one is recorded.
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None
        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Wrap the arguments in an :class:`OutputData` and enqueue it.

        The new entry always starts with an empty ``skip`` list.
        """
        entry = OutputData(
            outputs=outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            scheduler_outputs=scheduler_outputs,
            is_async=is_async,
            is_last_step=is_last_step,
            is_first_step_output=is_first_step_output,
            skip=[],
        )
        self.output_queue.append(entry)
121
122


123
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
124
    """An LLM engine that receives requests and generates texts.
125

Woosuk Kwon's avatar
Woosuk Kwon committed
126
    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine-args`.)

    Args:
        vllm_config: The engine configuration. Among others, it bundles:

            - model_config: The configuration related to the LLM model.
            - cache_config: The configuration related to the KV cache memory
              management.
            - parallel_config: The configuration related to distributed
              execution.
            - scheduler_config: The configuration related to the request
              scheduler.
            - device_config: The configuration related to the device.
            - lora_config (Optional): The configuration related to serving
              multi-LoRA.
            - speculative_config (Optional): The configuration related to
              speculative decoding.
            - prompt_adapter_config (Optional): The configuration related to
              serving prompt adapters.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
        stat_loggers (Optional): Custom statistics loggers keyed by name. When
            given (and ``log_stats`` is True), they replace the default
            logging and Prometheus loggers.
        mm_registry: The multi-modal registry handed to the input
            preprocessor.
        use_cached_outputs: Stored on the engine as ``use_cached_outputs`` and
            reported in the startup log.
    """
156

157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Temporarily enable request-output type validation.

        While the context is active, :meth:`validate_output` and
        :meth:`validate_outputs` perform ``isinstance`` checks.

        On exit the previous value of ``DO_VALIDATE_OUTPUT`` is restored.
        This holds even when the ``with`` body raises, so a failing block
        cannot leave validation switched on. A nested use also cannot
        switch it off while an enclosing block is still active.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            # Restore rather than hard-reset to False so nesting works.
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` typed as ``output_type``, checking it if enabled.

        The ``isinstance`` check only runs while validation is on: under
        static type checking, or inside :meth:`enable_output_validation`.
        Otherwise the value is passed through unchanged.

        Raises:
            TypeError: if validation is on and ``output`` is not an instance
                of ``output_type``.
        """
        must_check = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if must_check and not isinstance(output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return cast(_O, output)
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Check that every element of ``outputs`` is an ``output_type``.

        Elements are only inspected while validation is on (static type
        checking, or inside :meth:`enable_output_validation`). In that case a
        new list is returned. Otherwise ``outputs`` itself is handed back
        without copying.

        Raises:
            TypeError: for the first element that is not an ``output_type``
                (only while validation is on).
        """
        validated: List[_O]
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            validated = outputs
            return validated

        validated = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            validated.append(item)
        return validated

206
    # None when model_config.skip_tokenizer_init is set (see __init__).
    tokenizer: Optional[TokenizerGroup]

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the V0 engine from ``vllm_config``.

        Statement order below matters:
        - The tokenizer is created before the input preprocessor that
          receives it.
        - The KV caches are sized before the schedulers are built.
          Sizing writes ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks``,
          and the schedulers are built from ``cache_config``.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is set, because this is the
                V0 engine.
        """

        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Keep the sub-configs as attributes. The decoding and observability
        # configs fall back to default instances when unset.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        # With skip_tokenizer_init there is no tokenizer or detokenizer, and
        # get_tokenizer_for_seq below will fail its assertion if called.
        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        # Source of sequence ids (see next(self.seq_counter) when requests
        # are added).
        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.model_executor = executor_class(vllm_config=vllm_config)

        # Pooling models skip KV-cache initialization.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        # Per-virtual-engine state: one entry per pipeline-parallel rank
        # (pipeline_parallel_size entries in each list below).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Each virtual engine gets a callback bound to its own context.
        # NOTE(review): weak_bind presumably keeps these callbacks from
        # holding a strong reference to the engine (same GC concern as
        # get_tokenizer_for_seq) -- confirm in vllm.utils.
        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        # scheduler_cls may be a class or its fully qualified name.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        # One scheduler per virtual engine. It receives that engine's async
        # callback only when async output processing is enabled.
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        # NOTE: self.stat_loggers is only assigned when log_stats is True.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is only set up when an OTLP traces endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

414
415
416
417
418
419
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        Steps:
        - Ask the model executor how many GPU and swap (CPU) blocks are
          available.
        - Apply ``cache_config.num_gpu_blocks_override`` if it is set.
        - Record both block counts on ``self.cache_config``.
        - Have the executor allocate the caches.

        The total time taken is logged.
        """
        # perf_counter is monotonic: the duration cannot go negative or jump
        # if the wall clock is adjusted during (potentially long) profiling.
        start = time.perf_counter()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        # Persist the final counts on the shared cache config. The schedulers
        # are constructed from cache_config afterwards in __init__.
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
439

440
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Resolve the model-executor class for ``engine_config``.

        ``parallel_config.distributed_executor_backend`` must already be set
        (VllmConfig.__post_init__ is responsible for that). It is either:
        - an ExecutorBase subclass, which is used as-is; or
        - one of the strings "ray", "mp", "uni" or "external_launcher".

        Backend modules are imported lazily, so only the chosen one is
        loaded.

        Raises:
            TypeError: if a class that does not subclass ExecutorBase is
                given.
            ValueError: for any other unrecognized backend value.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)

        if isinstance(distributed_executor_backend, type):
            if issubclass(distributed_executor_backend, ExecutorBase):
                return distributed_executor_backend
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorBase. Got {distributed_executor_backend}.")

        if distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            return RayDistributedExecutor

        if distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            return MultiprocessingDistributedExecutor

        if distributed_executor_backend == "uni":
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            return UniProcExecutor

        if distributed_executor_backend == "external_launcher":
            # Executor driven by an external launcher.
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            return ExecutorWithExternalLauncher

        raise ValueError("unrecognized distributed_executor_backend: "
                         f"{distributed_executor_backend}")

478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Construct an engine directly from a resolved :class:`VllmConfig`.

        The executor class is picked by :meth:`_get_executor_cls`.
        Statistics logging is enabled unless ``disable_log_stats`` is True.
        """
        executor_cls = cls._get_executor_cls(vllm_config)
        engine_kwargs = dict(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return cls(**engine_kwargs)

494
495
496
497
498
499
500
501
502
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Create an engine from :class:`EngineArgs`.

        The engine config is derived from ``engine_args`` first. When
        ``envs.VLLM_USE_V1`` is set, the V1 ``LLMEngine`` is built instead of
        this class.
        """
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            target_cls = V1LLMEngine
        else:
            target_cls = cls

        return target_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
        )
516

517
518
519
520
521
    def __reduce__(self):
        """Refuse to pickle the engine.

        Failing loudly here guarantees the engine never ends up captured in
        something that gets serialized, such as the closure used to
        initialize Ray worker actors.
        """
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

522
523
524
525
526
527
    def __del__(self):
        """Shut down the model executor when the engine is garbage collected."""
        # __init__ may have raised before ``model_executor`` was assigned,
        # so look the attribute up defensively.
        executor = getattr(self, "model_executor", None)
        if executor:
            executor.shutdown()

528
529
    def get_tokenizer_group(self) -> TokenizerGroup:
        """Return the engine's tokenizer group.

        Raises:
            ValueError: if the engine was created with skip_tokenizer_init,
                in which case no tokenizer exists.
        """
        group = self.tokenizer
        if group is not None:
            return group
        raise ValueError("Unable to get tokenizer because "
                         "skip_tokenizer_init is True")
534

535
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request``.

        Delegates to the tokenizer group's ``get_lora_tokenizer``. Raises
        ValueError (via :meth:`get_tokenizer_group`) when the engine has no
        tokenizer.
        """
        tokenizer_group = self.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
540

541
    def _init_tokenizer(self) -> TokenizerGroup:
        """Build the tokenizer group from the model, scheduler and LoRA
        configs."""
        model_cfg = self.model_config
        scheduler_cfg = self.scheduler_config
        lora_cfg = self.lora_config
        return init_tokenizer_from_configs(model_config=model_cfg,
                                           scheduler_config=scheduler_cfg,
                                           lora_config=lora_cfg)
546

547
548
    def _verify_args(self) -> None:
        """Cross-check the sub-configs against each other.

        Delegates to each config's ``verify_with_*`` method. The optional
        LoRA and prompt-adapter configs are only checked when present.
        """
        parallel_config = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel_config)
        self.cache_config.verify_with_parallel_config(parallel_config)

        lora_config = self.lora_config
        if lora_config:
            lora_config.verify_with_model_config(self.model_config)
            lora_config.verify_with_scheduler_config(self.scheduler_config)

        adapter_config = self.prompt_adapter_config
        if adapter_config:
            adapter_config.verify_with_model_config(self.model_config)
557

558
559
560
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add an already-preprocessed request to the engine's request pool.

        Sampling requests with ``n > 1`` are delegated to
        :meth:`ParallelSampleSequenceGroup.add_request` and ``None`` is
        returned.

        Otherwise:
        - A decoder :class:`Sequence` is built, plus an encoder sequence when
          ``split_enc_dec_inputs`` yields encoder inputs.
        - Both are wrapped in a :class:`SequenceGroup` matching the type of
          ``params``.
        - The group is handed to the scheduler with the fewest unfinished
          sequence groups.

        Returns:
            The created sequence group, or ``None`` for ``n > 1`` sampling
            requests.

        Raises:
            ValueError: if ``params`` is neither SamplingParams nor
                PoolingParams.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            # Delegated to ParallelSampleSequenceGroup, which receives the
            # engine itself; no single group is returned in this case.
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        # The encoder sequence shares the decoder sequence's id.
        encoder_seq = (None if encoder_inputs is None else Sequence(
            seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
            prompt_adapter_request))

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        # min() returns the first minimal element, so ties still go to the
        # lowest-index scheduler, as with the previous list.index(min(...)).
        min_cost_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group

637
638
    def stop_remote_worker_execution_loop(self) -> None:
        """Forward to ``model_executor.stop_remote_worker_execution_loop()``."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
639

640
    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Typing-only signature: the preferred form taking ``prompt``."""

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Typing-only signature for the deprecated keyword-only ``inputs`` form.

        It accepts the same optional keywords as the ``prompt`` form,
        including ``tokenization_kwargs``. Type checkers therefore do not
        reject calls that the runtime implementation accepts.
        """
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            tokenization_kwargs: Optional[dict[str, Any]] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: The LoRA request to add.
            tokenization_kwargs: Extra keyword arguments forwarded to
                ``self.input_preprocessor.preprocess``.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter request to add.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            inputs: Deprecated alias for ``prompt``; when given it takes
                precedence over ``prompt``.

        Raises:
            ValueError: If ``lora_request`` is given while LoRA is not
                enabled, if a non-zero ``priority`` is given while the
                scheduler policy is not ``"priority"``, if guided decoding or
                logits processors are combined with multi-step scheduling, or
                if a token-id prompt contains an out-of-vocabulary id.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create `n` number of :class:`~vllm.Sequence` objects.
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if isinstance(params, SamplingParams) \
            and (params.guided_decoding or params.logits_processors) \
            and self.scheduler_config.num_scheduler_steps > 1:
            raise ValueError(
                "Guided decoding and logits processors are not supported "
                "in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # Embedding-only prompt: give it one placeholder token id (0) per
            # row of ``prompt_embeds``. Work on a shallow copy so the caller's
            # dict is not mutated; if it were, a reused dict would keep stale
            # placeholder ids and skip this branch on later calls.
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt = copy.copy(prompt)
            prompt["prompt_token_ids"] = [0] * seq_len

        if self.tokenizer is not None:
            self._validate_token_prompt(
                prompt,
                tokenizer=self.get_tokenizer(lora_request=lora_request))

        processed_inputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    def _validate_token_prompt(self, prompt: PromptType,
                               tokenizer: AnyTokenizer) -> None:
        """Reject token-id prompts whose ids fall outside the vocabulary.

        Non-token prompts and empty token prompts are accepted unchanged.

        Raises:
            ValueError: If any token id is negative or greater than
                ``tokenizer.max_token_id``.
        """
        # Guard against out-of-vocab tokens.
        # For some tokenizers, tokenizer.decode will happily return empty text
        # for token ids that are out of vocab, and we don't detect token ids
        # that are greater than the max token id before running the model.
        # However, these token ids will later crash a cuda kernel at runtime
        # with an index out of bounds error. This will crash the entire engine.
        # This needs to happen before multimodal input pre-processing, which
        # may add dummy <image> tokens that aren't part of the tokenizer's
        # vocabulary.
        if not is_token_prompt(prompt):
            return

        prompt_ids = prompt["prompt_token_ids"]
        if len(prompt_ids) == 0:
            # Empty prompt check is handled later
            return

        max_input_id = max(prompt_ids)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(
                "Token id {} is out of vocabulary".format(max_input_id))

        # Negative ids are not valid vocabulary indices either, and the
        # upper-bound check above cannot catch them.
        min_input_id = min(prompt_ids)
        if min_input_id < 0:
            raise ValueError(
                "Token id {} is out of vocabulary".format(min_input_id))

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams.

        ``sampling_params`` goes through four steps before it is stored on
        the new group:

        - It is checked against the model config's ``max_logprobs``.
        - It is passed through ``self._build_logits_processors``.
        - It is cloned.
        - It is updated from ``self.generation_config_fields``.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model config's ``max_logprobs``.
        """
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

        sampling_params = self._build_logits_processors(
            sampling_params, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()

        sampling_params.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        # Create the sequence group.
        # draft_size is 1 without a speculative config, otherwise
        # num_speculative_tokens + 1.
        draft_size = 1
        if self.vllm_config.speculative_config is not None:
            draft_size = \
                self.vllm_config.speculative_config.num_speculative_tokens + 1
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority,
            draft_size=draft_size)

        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with PoolingParams.

        ``pooling_params`` is cloned (a defensive copy) before it is stored
        on the new group.
        """
        # Defensive copy of PoolingParams, which are used by the pooler
        pooling_params = pooling_params.clone()
        # Create the sequence group.
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)
        return seq_group

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        The abort is forwarded to every scheduler in ``self.scheduler``,
        together with ``self.seq_id_to_seq_group``.

        Args:
            request_id: The ID(s) of the request to abort.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for scheduler in self.scheduler:
            scheduler.abort_seq_group(
                request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)

    def get_vllm_config(self) -> VllmConfig:
        """Gets the vllm configuration (``self.vllm_config``)."""
        return self.vllm_config

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration (``self.model_config``)."""
        return self.model_config

    def get_parallel_config(self) -> ParallelConfig:
        """Gets the parallel configuration (``self.parallel_config``)."""
        return self.parallel_config

    def get_decoding_config(self) -> DecodingConfig:
        """Gets the decoding configuration (``self.decoding_config``)."""
        return self.decoding_config

    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration held by this engine."""
        scheduler_config = self.scheduler_config
        return scheduler_config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Gets the LoRA configuration.

        Returns:
            ``self.lora_config``. It may be ``None`` when LoRA is not enabled
            (``add_request`` treats a falsy ``lora_config`` as "LoRA not
            enabled").
        """
        return self.lora_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests.

        This is the sum of ``get_num_unfinished_seq_groups()`` over every
        scheduler in ``self.scheduler``.
        """
        return sum(scheduler.get_num_unfinished_seq_groups()
                   for scheduler in self.scheduler)

    def has_unfinished_requests(self) -> bool:
        """Returns True if any scheduler still has unfinished sequences."""
        return any(scheduler.has_unfinished_seqs()
                   for scheduler in self.scheduler)

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """
        Returns True if there are unfinished requests for the virtual engine.

        Args:
            virtual_engine: Index into ``self.scheduler`` that selects the
                scheduler to query.
        """
        return self.scheduler[virtual_engine].has_unfinished_seqs()

    def reset_prefix_cache(self, device: Optional["Device"] = None) -> bool:
        """Reset prefix cache for all devices.

        Every scheduler in ``self.scheduler`` is asked to reset its prefix
        cache, even if an earlier one fails.

        Args:
            device: Forwarded unchanged to each scheduler's
                ``reset_prefix_cache``.

        Returns:
            True only if every scheduler reported a successful reset (True
            when there are no schedulers).
        """
        # Call every scheduler unconditionally. A chained
        # ``success and scheduler.reset_prefix_cache(device)`` would
        # short-circuit and skip the remaining schedulers after the first
        # failure.
        results = [
            scheduler.reset_prefix_cache(device)
            for scheduler in self.scheduler
        ]
        return all(results)

    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Store pooling output on a sequence group and finish its sequences.

        Only ``outputs[0].data`` is used; it is assigned to
        ``seq_group.pooled_data``. Every sequence in the group is then marked
        ``SequenceStatus.FINISHED_STOPPED``.
        """
        seq_group.pooled_data = outputs[0].data

        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_STOPPED

    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]) -> None:
        """Update num_computed_tokens for prompt sequences when Multi-Step is
        enabled.

        Decode (non-prompt) groups are left untouched.

        Args:
            seq_group: SequenceGroup to update the num_computed_tokens for.
            seq_group_meta: Metadata of the given SequenceGroup.
            is_first_step_output: When available, indicates whether the
                appended output token is the output of the first step in
                multi-step. None indicates that outputs from all steps in
                multi-step are submitted in a single burst.
        """

        assert self.scheduler_config.is_multi_step

        if not seq_group_meta.is_prompt:
            # num_computed_token updates for multi-step decodes happen after
            # the tokens are appended to the sequence.
            return

        do_update: bool = False
        if self.scheduler_config.chunked_prefill_enabled:
            # In multi-step + chunked-prefill case, the prompt sequences
            # that are scheduled are fully processed in the first step.
            do_update = is_first_step_output is None or is_first_step_output
        else:
            # Normal multi-step decoding case. In this case prompt-sequences
            # are actually single-stepped. Always update in this case.
            assert seq_group.state.num_steps == 1
            do_update = True

        if do_update:
            seq_group.update_num_computed_tokens(
                seq_group_meta.token_chunk_size)

    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and collect the resulting request outputs.

        Takes the oldest entry of ``ctx.output_queue``. It is popped, or only
        peeked when ``request_id`` is given. Each entry unpacks as:

        - ``outputs``
        - ``seq_group_metadata_list``
        - ``scheduler_outputs``
        - ``is_async``
        - ``is_last_step``
        - ``is_first_step_output``
        - ``skip``

        The outputs are applied to the scheduled sequence groups, and the
        created request outputs are appended to ``ctx.request_outputs``. When
        ``self.process_request_outputs_callback`` is set, those outputs are
        handed to it and ``ctx.request_outputs`` is cleared. Always returns
        None.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, then only this request is going to be
                processed. Its index is appended to the entry's ``skip`` list
                so it is not processed again when the rest of the entry is
                handled.
        """
        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Re-organize outputs from [step][sequence group] into
            # [sequence group][step].
            if self.scheduler_config.is_multi_step:
                outputs_by_sequence_group = create_output_by_sequence_group(
                    outputs, len(seq_group_metadata_list))
            elif self.speculative_config:
                # Decodes are multi-steps while prefills are not, outputting at
                # most 1 token. Separate them so that we can trigger chunk
                # processing without having to pad or copy over prompts K times
                # to match decodes structure (costly with prompt_logprobs).
                num_prefills = sum(sg.is_prompt
                                   for sg in seq_group_metadata_list)
                prefills, decodes = outputs[:num_prefills], outputs[
                    num_prefills:]
                outputs_by_sequence_group = create_output_by_sequence_group(
                    decodes,
                    num_seq_groups=len(seq_group_metadata_list) - num_prefills)
                outputs_by_sequence_group = [p.outputs for p in prefills
                                             ] + outputs_by_sequence_group
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
        else:
            outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
                        seq_group_meta.token_chunk_size or 0)

            if outputs:
                # Accumulate forward/execute timings reported by each
                # SamplerOutput into the group's metrics.
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            if self.model_config.runner_type == "pooling":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Create the outputs
        for i in indices:
            if i in skip or i in finished_before or i in finished_now:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        for seq_group in scheduler_outputs.ignored_seq_groups:
            # Unfinished DELTA-mode groups produce no output here.
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _advance_to_next_step(
            self, output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done inside output processor, but it is
        required if the worker is to perform async forward pass to next step.

        For every scheduled sequence group that is not yet finished, this
        method does two things:

        - It updates ``num_computed_tokens``. With multi-step scheduling it
          uses the multi-step prefill helper; otherwise it adds the
          metadata's ``token_chunk_size``, treating None as 0.
        - If ``do_sample`` is set, it appends the single sampled token to the
          group's single sequence.

        Args:
            output: Per-sequence-group outputs of one model run; zipped
                positionally with the two lists below.
            seq_group_metadata_list: Metadata of the scheduled groups.
            scheduled_seq_groups: The scheduled sequence groups.
        """
        is_multi_step = self.scheduler_config.is_multi_step
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

            if is_multi_step:
                # Updates happen only if the sequence is prefill
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, seq_group_metadata,
                    seq_group.state.num_steps == 1)
            else:
                seq_group.update_num_computed_tokens(
                    seq_group_metadata.token_chunk_size or 0)

            if not seq_group_metadata.do_sample:
                continue

            assert len(sequence_group_outputs.samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1)")
            sample = sequence_group_outputs.samples[0]

            assert len(seq_group.seqs) == 1
            seq = seq_group.seqs[0]

            # Multi-step only: if the sequence had no uncomputed tokens
            # *before* this append, num_computed_tokens is left unchanged;
            # otherwise it is advanced by one after the append.
            is_prefill_append = (is_multi_step
                                 and seq.data.get_num_uncomputed_tokens() == 0)
            seq.append_token_id(sample.output_token, sample.logprobs,
                                sample.output_embed)
            if is_multi_step and not is_prefill_append:
                seq_group.update_num_computed_tokens(1)

    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
Antoni Baum's avatar
Antoni Baum committed
1285
1286
        """Performs one decoding iteration and returns newly generated results.

1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refer to a group of sequences
                  that are generated from the same prompt.

1302
            - Step 2: Calls the distributed executor to execute the model.
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
1324
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
1325
1326
1327
1328
1329
1330
1331
1332
1333
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
Antoni Baum's avatar
Antoni Baum committed
1334
        """
1335
1336
1337
1338
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")
1339

1340
        # For llm_engine, there is no pipeline parallel support, so the engine
1341
        # used is always 0.
1342
1343
        virtual_engine = 0

1344
1345
        # These are cached outputs from previous iterations. None if on first
        # iteration
1346
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
1347
1348
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
1349
        allow_async_output_proc = cached_outputs.allow_async_output_proc
1350

1351
1352
        ctx = self.scheduler_contexts[virtual_engine]

1353
1354
1355
        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

1356
1357
1358
        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
1359
1360
1361
1362
1363
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
1364
            # Schedule iteration
1365
            (seq_group_metadata_list, scheduler_outputs,
1366
1367
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
1368

1369
1370
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
1371

1372
1373
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
1374
1375
1376
1377
1378
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]
1379

1380
1381
            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
1382
                self._process_model_outputs(ctx=ctx)
1383

1384
1385
1386
1387
1388
            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
1389
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
1390
                    allow_async_output_proc)
1391
1392
        else:
            finished_requests_ids = list()
1393
1394
1395

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
Antoni Baum's avatar
Antoni Baum committed
1396

1397
        if not scheduler_outputs.is_empty():
1398
1399
1400
1401
1402
1403

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
1404
                self._get_last_sampled_token_ids(virtual_engine)
1405

1406
            execute_model_req = ExecuteModelRequest(
1407
1408
1409
1410
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
1411
1412
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
1413
1414
1415
1416
1417
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

1418
            if allow_async_output_proc:
1419
1420
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]
1421

1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise
1439

1440
            # We need to do this here so that last step's sampled_token_ids can
1441
1442
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
1443
                self._update_cached_scheduler_output(virtual_engine, outputs)
1444
        else:
1445
1446
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
1447
1448
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
1449
            # No outputs in this case
1450
            outputs = []
Antoni Baum's avatar
Antoni Baum committed
1451

1452
1453
1454
1455
1456
1457
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
1458
            # clear the cache if we have finished all the steps.
1459
1460
1461
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()

1462
1463
1464
1465
1466
1467
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

1468
            # Add results to the output_queue
1469
1470
1471
1472
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
1473
1474
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)
1475
1476
1477

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
1478
                    "Async postprocessor expects only a single output set")
1479

1480
                self._advance_to_next_step(
1481
                    outputs[0], seq_group_metadata_list,
1482
                    scheduler_outputs.scheduled_seq_groups)
1483

1484
            # Check if need to run the usual non-async path
1485
            if not allow_async_output_proc:
1486
                self._process_model_outputs(ctx=ctx)
1487

1488
                # Log stats.
1489
                self.do_log_stats(scheduler_outputs, outputs)
1490

1491
1492
1493
                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
1494
            # Multi-step case
1495
            return ctx.request_outputs
1496

1497
        if not self.has_unfinished_requests():
1498
1499
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
1500
                self._process_model_outputs(ctx=ctx)
1501
            assert len(ctx.output_queue) == 0
1502

1503
1504
1505
1506
1507
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
1508
            logger.debug("Stopping remote worker execution loop.")
1509
1510
            self.model_executor.stop_remote_worker_execution_loop()

1511
        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1512

1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Aborts a single request, and caches the scheduler outputs minus that
        request. This allows the next step to continue processing the remaining
        requests without having to re-run the scheduler."""

        # Abort the request and remove its sequence group from the current
        # schedule
        self.abort_request(request_id)
        for i, metadata in enumerate(seq_group_metadata_list):
            if metadata.request_id == request_id:
                del seq_group_metadata_list[i]
                break
        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
            if group.seq_group.request_id == request_id:
                del scheduler_outputs.scheduled_seq_groups[i]
                break

        # If there are still other sequence groups left in the schedule, cache
        # them and flag the engine to reuse the schedule.
        if len(seq_group_metadata_list) > 0:
            self._skip_scheduling_next_step = True
            # Reuse multi-step caching logic
            self._cache_scheduler_outputs_for_multi_step(
                virtual_engine=virtual_engine,
                scheduler_outputs=scheduler_outputs,
                seq_group_metadata_list=seq_group_metadata_list,
                allow_async_output_proc=allow_async_output_proc)

1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for nowto make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        if any([
                seq_group.state.remaining_steps != ref_remaining_steps
                for seq_group in seq_group_metadata_list[1:]
        ]):
1560
1561
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")
1562
1563
1564
1565
1566
1567

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
1568
1569
1570
1571
1572
1573
1574
1575
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                and output[0] is not None):
            last_output = output[-1]
            assert last_output is not None
            assert last_output.sampled_token_ids_cpu is not None
            assert last_output.sampled_token_ids is None
            assert last_output.sampled_token_probs is None
            self.cached_scheduler_outputs[
                virtual_engine].last_output = last_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1601
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register an additional stat logger under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled for this engine.
            KeyError: If a logger with that name is already registered.
        """
        if self.log_stats:
            registered = self.stat_loggers
            if logger_name not in registered:
                registered[logger_name] = logger
                return
            raise KeyError(f"Logger with name {logger_name} already exists.")
        raise RuntimeError(
            "Stat logging is disabled. Set `disable_log_stats=False` "
            "argument to enable.")

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled for this engine.
            KeyError: If no logger with that name is registered.
        """
        if self.log_stats:
            registered = self.stat_loggers
            if logger_name in registered:
                del registered[logger_name]
                return
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        raise RuntimeError(
            "Stat logging is disabled. Set `disable_log_stats=False` "
            "argument to enable.")

1619
1620
1621
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Build a stats snapshot and pass it to every registered logger.

        Does nothing when stat logging is disabled. All arguments are
        forwarded unchanged to ``_get_stats``. Calling it with no arguments
        (e.g. to force a log while no requests are active) still reports the
        engine-wide system stats.
        """
        if not self.log_stats:
            return
        snapshot = self._get_stats(scheduler_outputs, model_output,
                                   finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(snapshot)
1630

1631
1632
1633
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional, indices (into
                ``scheduler_outputs.scheduled_seq_groups``) of sequence groups
                that were finished before. These are ignored, and one token
                per such group is subtracted from the batched-token count.
            skip: Optional, indices of sequence groups that were preempted.
                These are ignored.

        Returns:
            A ``Stats`` snapshot with system, iteration and request metrics.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        time_in_queue_requests: List[float] = []
        model_forward_time_requests: List[float] = []
        model_execute_time_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            # Membership is tested once per scheduled group; sets make each
            # test O(1) instead of a linear scan of the index lists.
            finished_before_idx = set(finished_before or ())
            skip_idx = set(skip or ())

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_before_idx:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skip_idx:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_token_latency()
                    time_per_output_tokens_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
                    if seq_group.metrics.time_in_queue is not None:
                        time_in_queue_requests.append(
                            seq_group.metrics.time_in_queue)
                    if seq_group.metrics.model_forward_time is not None:
                        model_forward_time_requests.append(
                            seq_group.metrics.model_forward_time)
                    if seq_group.metrics.model_execute_time is not None:
                        # NOTE(review): the *1000 presumably converts seconds
                        # to milliseconds — confirm against the metric unit.
                        model_execute_time_requests.append(
                            seq_group.metrics.model_execute_time * 1000)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and isinstance(model_output[0], SamplerOutput) and (
                model_output[0].spec_decode_worker_metrics is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            time_in_queue_requests=time_in_queue_requests,
            model_forward_time_requests=model_forward_time_requests,
            model_execute_time_requests=model_execute_time_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            # max_lora_stat is already a string.
            max_lora=max_lora_stat,
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
1900

1901
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the model executor; returns its result."""
        executor = self.model_executor
        return executor.add_lora(lora_request)
1903
1904

    def remove_lora(self, lora_id: int) -> bool:
        """Unload LoRA ``lora_id`` via the model executor; returns its result."""
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1906

1907
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model executor."""
        executor = self.model_executor
        return executor.list_loras()
1909

1910
1911
1912
    def pin_lora(self, lora_id: int) -> bool:
        """Pin LoRA ``lora_id`` via the model executor; returns its result."""
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Load a prompt adapter via the model executor; returns its result."""
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Unload a prompt adapter via the model executor; returns its result."""
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the prompt adapter ids reported by the model executor."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1923
1924
1925
1926
1927
1928
    def start_profile(self) -> None:
        """Ask the model executor to start profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Ask the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

1929
1930
1931
1932
1933
    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at ``level``.

        Sleep mode must be enabled in the model config; this is checked with
        an ``assert``.
        """
        assert self.vllm_config.model_config.enable_sleep_mode, \
            "Sleep mode is not enabled in the model config"
        executor = self.model_executor
        executor.sleep(level=level)

1934
    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the model executor up, forwarding ``tags`` unchanged.

        Sleep mode must be enabled in the model config; this is checked with
        an ``assert``.
        """
        assert self.vllm_config.model_config.enable_sleep_mode, \
            "Sleep mode is not enabled in the model config"
        executor = self.model_executor
        executor.wake_up(tags)
1938

1939
1940
1941
    def is_sleeping(self) -> bool:
        """Return the model executor's ``is_sleeping`` attribute."""
        executor = self.model_executor
        return executor.is_sleeping

1942
    def check_health(self) -> None:
        """Delegate the health check to the model executor."""
        executor = self.model_executor
        executor.check_health()
1944
1945
1946
1947

    def is_tracing_enabled(self) -> bool:
        """Return ``True`` when a tracer is configured."""
        tracer = self.tracer
        return tracer is not None

1948
1949
1950
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for every finished scheduled sequence group.

        No-op when no tracer is configured. Groups whose index appears in
        ``finished_before`` are skipped, to avoid tracing them twice when
        async output processing is used.
        """
        if self.tracer is None:
            return

        already_traced = finished_before or []
        for position, scheduled in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            if position in already_traced:
                continue
            group = scheduled.seq_group
            if group.is_finished():
                self.create_trace_span(group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit an ``llm_request`` server span describing ``seq_group``.

        Does nothing when tracing is disabled (``self.tracer is None``) or the
        group has no ``sampling_params``. The span starts at the request's
        arrival time and uses the trace context extracted from
        ``seq_group.trace_headers``. It carries request parameters, token
        usage and latency attributes.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        sampling_params = seq_group.sampling_params
        metrics = seq_group.metrics
        # Span start time is passed as integer nanoseconds.
        arrival_time_nano_seconds = int(metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   sampling_params.n)
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Latency attributes. Each source value is checked for None, as
            # scheduler_time and friends already were. An unrecorded
            # timestamp then skips its attribute instead of raising TypeError
            # in the subtraction below.
            if metrics.time_in_queue is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                    metrics.time_in_queue)
            if metrics.first_token_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                    metrics.first_token_time - metrics.arrival_time)
            if metrics.finished_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_E2E,
                    metrics.finished_time - metrics.arrival_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # Only this value is divided by 1000; presumably it is
                # recorded in milliseconds (TODO confirm against producer).
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)

    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Validate the encoder part (if present) and decoder part of inputs.

        ``inputs`` is split with ``split_enc_dec_inputs``. Each part is then
        checked by ``_validate_model_input``, which raises ``ValueError`` on
        failure.
        """
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        # Decoder-only inputs yield no encoder part.
        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate one processed (encoder or decoder) prompt.

        Args:
            prompt_inputs: The processed prompt. Its ``prompt_token_ids`` may
                be absent, in which case it is treated as empty.
            lora_request: Used to select the LoRA-specific tokenizer.
            prompt_type: Which side of the model the prompt is for; used for
                the exemptions below and in error messages.

        Raises:
            ValueError: If the prompt has no token ids, unless it is a
                multimodal encoder prompt or an ``"embeds"`` prompt. Also
                raised if it has more tokens than
                ``model_config.max_model_len``, unless it is a multimodal
                encoder prompt whose processor sets
                ``pad_dummy_encoder_prompt``.
        """
        model_config = self.model_config
        tokenizer = (None if self.tokenizer is None else
                     self.tokenizer.get_lora_tokenizer(lora_request))

        prompt_ids = prompt_inputs.get("prompt_token_ids", [])
        if not prompt_ids:
            # `elif` is required: both exemptions must skip the `else` raise.
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            elif prompt_inputs["type"] == "embeds":
                pass  # embeds-type inputs may carry no prompt_token_ids
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        max_prompt_len = model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer or object(),  # Dummy if no tokenizer
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the guided_decoding,
        logit_bias, allowed_token_ids and bad_words fields in sampling_params.

        guided_decoding, logit_bias and allowed_token_ids are cleared once
        their processors are built. All constructed processors are appended
        to the logits_processors field. Returns the modified sampling params.

        The caller's object is never modified. Any change is applied to a
        shallow copy, and logits_processors is rebound to a new list rather
        than extended in place. A SamplingParams instance reused across
        several requests therefore keeps its fields and cannot accumulate
        processors.
        """
        has_guided = sampling_params.guided_decoding is not None
        has_openai = bool(sampling_params.logit_bias
                          or sampling_params.allowed_token_ids)
        # Truthiness also treats a missing (None) bad_words as "no words".
        has_bad_words = bool(sampling_params.bad_words)
        if not (has_guided or has_openai or has_bad_words):
            return sampling_params

        # Defensively copy sampling params: guided decoding logits processors
        # can have different state for each request, and the fields below
        # must not be cleared on an object the caller may reuse.
        sampling_params = copy.copy(sampling_params)
        logits_processors = []

        if has_guided:
            # Copy as well, so defaulting the backend below does not mutate
            # the caller's guided decoding params (shared by the shallow copy).
            guided_decoding = copy.copy(sampling_params.guided_decoding)

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.backend

            if self.decoding_config.reasoning_backend:
                logger.debug("Building with reasoning backend %s",
                             self.decoding_config.reasoning_backend)

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config,
                reasoning_backend=self.decoding_config.reasoning_backend,
            )
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if has_openai:
            tokenizer = self.get_tokenizer(lora_request=lora_request)

            processors = get_openai_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if has_bad_words:
            tokenizer = self.get_tokenizer(lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

        if logits_processors:
            # Rebind instead of extending: after the shallow copy, the
            # existing list object is still shared with the caller.
            existing = sampling_params.logits_processors or []
            sampling_params.logits_processors = [*existing, *logits_processors]

        return sampling_params

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke ``method`` through the model executor's ``collective_rpc``.

        ``method``, ``timeout``, ``args`` and ``kwargs`` are forwarded
        positionally and unchanged. The executor's result is returned as-is.
        """
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)


# Import-time engine selection. When `envs.is_set("VLLM_USE_V1")` holds
# (presumably: the variable is explicitly present in the environment -- TODO
# confirm) and `envs.VLLM_USE_V1` is truthy, the module-level name `LLMEngine`
# is rebound to the V1 engine class. Code that looks `LLMEngine` up from this
# module after import then gets the V1 implementation instead of the class
# defined above.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    LLMEngine = V1LLMEngine  # type: ignore