llm_engine.py 91.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import copy
Antoni Baum's avatar
Antoni Baum committed
4
import time
5
from collections import Counter as collectionsCounter
6
from collections import deque
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
11
                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
12
from typing import Sequence as GenericSequence
13
from typing import Set, Type, Union, cast, overload
14

15
import torch
16
from typing_extensions import TypeVar, deprecated
17

18
import vllm.envs as envs
19
20
21
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
22
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
28
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
29
30
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
31
from vllm.executor.executor_base import ExecutorBase
32
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
33
from vllm.inputs.parse import split_enc_dec_inputs
34
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.logger import init_logger
36
from vllm.logits_process import get_bad_words_logits_processors
37
from vllm.lora.request import LoRARequest
38
39
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
40
from vllm.model_executor.layers.sampler import SamplerOutput
41
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
42
from vllm.multimodal.processing import EncDecMultiModalProcessor
43
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
44
45
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
46
from vllm.prompt_adapter.request import PromptAdapterRequest
47
from vllm.sampling_params import RequestOutputKind, SamplingParams
48
49
50
51
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceStatus)
52
53
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
54
from vllm.transformers_utils.detokenizer import Detokenizer
55
from vllm.transformers_utils.tokenizer import AnyTokenizer
56
from vllm.transformers_utils.tokenizer_group import (
57
    TokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
58
59
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
60
61
from vllm.utils import (Counter, Device, deprecate_kwargs,
                        resolve_obj_by_qualname, weak_bind)
62
from vllm.version import __version__ as VLLM_VERSION
63
from vllm.worker.model_runner_base import InputProcessingError
64
65

logger = init_logger(__name__)

# Interval in seconds passed as `local_interval` to the default stat loggers
# created in `LLMEngine.__init__`.
_LOCAL_LOGGING_INTERVAL_SEC = 5

# Output type accepted by `LLMEngine.validate_output(s)`.
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
_R = TypeVar("_R", default=Any)
70
71


72
73
74
75
76
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step.

    Every field defaults to an "empty" value (None / False), so a freshly
    constructed instance represents an empty cache. `LLMEngine.__init__`
    creates one instance per pipeline-parallel virtual engine.
    """
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None
    allow_async_output_proc: bool = False
    last_output: Optional[SamplerOutput] = None
79
80


81
82
83
84
85
86
class OutputData(NamedTuple):
    """A batch of model outputs queued on a `SchedulerContext`.

    Instances are built by `SchedulerContext.append_output`, which pairs the
    sampler outputs with the scheduler metadata that produced them.
    """
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # Indicates if this output is from the first step of the
    # multi-step. When multi-step is disabled, this is always
    # set to True.
    # is_first_step_output is invalid when `outputs` has
    # outputs from multiple steps.
    is_first_step_output: Optional[bool]
    # Always starts as an empty list in `SchedulerContext.append_output`.
    # NOTE(review): presumably indices of entries to skip while processing
    # this output; confirm against the consumers of `OutputData.skip`.
    skip: List[int]


96
class SchedulerContext:
    """Mutable per-virtual-engine state shared by scheduling and output
    processing.

    `LLMEngine.__init__` creates one context per pipeline-parallel virtual
    engine. Model outputs are queued here through `append_output`.
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        """
        Args:
            multi_step_stream_outputs: Stored as-is. The engine passes
                `scheduler_config.multi_step_stream_outputs`.
        """
        # Pending model outputs, appended by `append_output`.
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Queue one batch of model outputs, with its scheduling metadata, as
        an `OutputData` at the right end of `output_queue`.

        The queued entry always starts with an empty `skip` list.
        """
        self.output_queue.append(
            OutputData(outputs=outputs,
                       seq_group_metadata_list=seq_group_metadata_list,
                       scheduler_outputs=scheduler_outputs,
                       is_async=is_async,
                       is_last_step=is_last_step,
                       is_first_step_output=is_first_step_output,
                       skip=[]))
121
122


123
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
124
    """An LLM engine that receives requests and generates texts.
125

Woosuk Kwon's avatar
Woosuk Kwon committed
126
    This is the main class for the vLLM engine. It receives requests
127
128
129
130
131
132
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

133
134
    The {class}`~vllm.LLM` class wraps this class for offline batched inference
    and the {class}`AsyncLLMEngine` class wraps this class for online serving.
135

136
137
    The config arguments are derived from {class}`~vllm.EngineArgs`. (See
    {ref}`engine-args`)
138
139
140
141
142
143
144

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
145
        device_config: The configuration related to the device.
146
147
148
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
149
150
        executor_class: The model executor class for managing distributed
            execution.
151
        prompt_adapter_config (Optional): The configuration related to serving
152
            prompt adapters.
153
        log_stats: Whether to log statistics.
154
        usage_context: Specified entry point, used for usage info collection.
155
    """
156

157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return `output` cast to `_O`, optionally checking its type.

        At runtime `TYPE_CHECKING` is False, so the isinstance check only
        runs while `DO_VALIDATE_OUTPUT` is True.

        Raises:
            TypeError: If the check runs and `output` is not an instance of
                `output_type`.
        """
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

        return cast(_O, output)
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return `outputs` typed as a list of `_O`, optionally checking each
        element.

        When validation is off (`DO_VALIDATE_OUTPUT` False; `TYPE_CHECKING`
        is False at runtime), the very same sequence object is returned
        without copying. When it is on, a new list holding the checked
        elements is returned.

        Raises:
            TypeError: On the first element that is not an instance of
                `output_type` while validation is enabled.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # Fast path: hand the caller's sequence back untouched.
            return outputs  # type: ignore[return-value]

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)
        return checked

206
    tokenizer: Optional[TokenizerGroup]
207

208
209
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Create a V0 engine from a resolved `VllmConfig`.

        Args:
            vllm_config: Engine-wide configuration. Its sub-configs are also
                exposed as attributes (`model_config`, `cache_config`, ...).
            executor_class: Executor type, instantiated as
                `executor_class(vllm_config=vllm_config)`.
            log_stats: Whether to set up `self.stat_loggers`.
            usage_context: Entry point reported with usage statistics.
            stat_loggers: Stat loggers to use when `log_stats` is True. If
                None, a logging logger and a Prometheus logger are created.
            mm_registry: Multi-modal registry handed to the input
                preprocessor.
            use_cached_outputs: Stored as `self.use_cached_outputs`.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is enabled.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        # Fall back to default-constructed configs when none were provided.
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.model_executor = executor_class(vllm_config=vllm_config)

        # KV caches are only set up for non-pooling runners.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        # One cached scheduler-output slot and one scheduler context per
        # virtual engine (pipeline-parallel stage).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            # NOTE(review): weak_bind presumably binds the method through a
            # weak reference so these callbacks do not keep the engine
            # alive; confirm in vllm.utils.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: for non-pooling runners, the cache_config here has been
        # updated with the numbers of GPU and CPU blocks, which are profiled
        # in the distributed executor (see `_initialize_kv_caches`).
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging. `self.stat_loggers` is only set when log_stats.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(self.scheduler_config.max_model_len,
                                         get_tokenizer_for_seq),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

412
413
414
415
416
417
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache. `cache_config.num_gpu_blocks_override`, when
        set, replaces the profiled GPU block count. The final counts are
        written back to `self.cache_config` before the executor allocates
        the cache.
        """
        # perf_counter is monotonic, so the reported duration is not skewed
        # by wall-clock adjustments the way time.time() can be.
        start = time.perf_counter()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
437

438
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Resolve the executor class for
        `engine_config.parallel_config.distributed_executor_backend`.

        The backend may be an `ExecutorBase` subclass, which is used as-is,
        or one of the strings "ray", "mp", "uni" or "external_launcher".
        Backend modules are imported lazily, only when selected.

        Raises:
            TypeError: If the backend is a class that does not subclass
                `ExecutorBase`.
            ValueError: If the backend is any other unrecognized value.
        """
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            # The mp backend cannot be combined with Ray SPMD workers.
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingDistributedExecutor
        elif distributed_executor_backend == "uni":
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # executor with external launcher
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("unrecognized distributed_executor_backend: "
                             f"{distributed_executor_backend}")
        return executor_class

476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Build an engine from `vllm_config`.

        The executor class is resolved with `_get_executor_cls`. Statistics
        logging is on unless `disable_log_stats` is True.
        """
        executor_cls = cls._get_executor_cls(vllm_config)
        should_log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=should_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

492
493
494
495
496
497
498
499
500
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        When `envs.VLLM_USE_V1` is set, construction is delegated to the V1
        `LLMEngine` in `vllm.v1.engine.llm_engine` instead of this class.
        `engine_args.disable_log_stats` is forwarded as `disable_log_stats`.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)

        engine_cls = cls
        if envs.VLLM_USE_V1:
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            engine_cls = V1LLMEngine

        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
        )
514

515
516
517
518
519
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

520
521
522
523
524
525
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

526
527
    def get_tokenizer_group(self) -> TokenizerGroup:
        """Return the engine's tokenizer group.

        Raises:
            ValueError: If `self.tokenizer` is None, which `__init__` does
                when `model_config.skip_tokenizer_init` is set.
        """
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        return self.tokenizer
532

533
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group provides for
        `lora_request` (presumably the base tokenizer when None; confirm in
        `TokenizerGroup.get_lora_tokenizer`).

        Raises:
            ValueError: Propagated from `get_tokenizer_group` when tokenizer
                initialization was skipped.
        """
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
538

539
    def _init_tokenizer(self) -> TokenizerGroup:
        """Build the tokenizer group from the model, scheduler and LoRA
        configs."""
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            lora_config=self.lora_config)
544

545
546
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
547
        self.cache_config.verify_with_parallel_config(self.parallel_config)
548
549
550
551
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
552
553
554
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)
555

556
557
558
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Sampling requests with `params.n > 1` are handed to
        `ParallelSampleSequenceGroup.add_request`, and None is returned.
        Otherwise a single `SequenceGroup` is built, given to the scheduler
        with the fewest unfinished sequence groups, and returned.

        Raises:
            ValueError: If `params` is neither `SamplingParams` nor
                `PoolingParams`.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        # Encoder-decoder models get a second Sequence for the encoder side;
        # it shares seq_id with the decoder sequence.
        encoder_seq = (None if encoder_inputs is None else Sequence(
            seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
            prompt_adapter_request))

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        costs = [
            scheduler.get_num_unfinished_seq_groups()
            for scheduler in self.scheduler
        ]
        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group

635
636
    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()
637

638
    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only stub for the preferred `prompt=` signature. The runtime
        # implementation is the undecorated `add_request` that follows the
        # overloads.
        ...

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only stub for the deprecated keyword-only `inputs=` form.
        # `tokenization_kwargs` is accepted here too, for parity with the
        # `prompt=` overload and the implementation, which both take it.
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            tokenization_kwargs: Optional[dict[str, Any]] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                {class}`~vllm.SamplingParams` for text generation.
                {class}`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: The LoRA request to add.
            tokenization_kwargs: Extra keyword arguments forwarded to the
                input preprocessor.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter request to add.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            inputs: Deprecated alias of ``prompt``; when given it takes
                precedence over ``prompt``.

        Raises:
            ValueError: If a LoRA request is given while LoRA is not enabled,
                if a non-zero priority is given without priority scheduling,
                or if guided decoding / logits processors are requested
                together with multi-step scheduling.

        Details:
            - Set arrival_time to the current time if it is None.
            - For embedding-only dict prompts, fill placeholder
              ``prompt_token_ids`` (on a copy; the caller's dict is untouched).
            - Preprocess the prompt with the input preprocessor.
            - Create `n` number of {class}`~vllm.Sequence` objects.
            - Create a {class}`~vllm.SequenceGroup` object
              from the list of {class}`~vllm.Sequence`.
            - Add the {class}`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    sampling_params)
            >>> # continue the request processing
            >>> ...
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if (isinstance(params, SamplingParams)
                and (params.guided_decoding or params.logits_processors)
                and self.scheduler_config.num_scheduler_steps > 1):
            raise ValueError(
                "Guided decoding and logits processors are not supported "
                "in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # Embedding-only prompts get one placeholder token ID per
            # embedding row. Work on a shallow copy so the caller's prompt
            # dict is not mutated as a side effect of submitting it.
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt = copy.copy(prompt)
            prompt["prompt_token_ids"] = [0] * seq_len

        processed_inputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model config's ``max_logprobs``.
        """
        max_logprobs = self.get_model_config().max_logprobs
        for requested in (sampling_params.logprobs,
                          sampling_params.prompt_logprobs):
            if requested and requested > max_logprobs:
                raise ValueError(f"Cannot request more than "
                                 f"{max_logprobs} logprobs.")

        sampling_params = self._build_logits_processors(
            sampling_params, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()

        sampling_params.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        # draft_size is num_speculative_tokens + 1 when a speculative config
        # is present, otherwise 1.
        draft_size = 1
        if self.vllm_config.speculative_config is not None:
            draft_size = \
                self.vllm_config.speculative_config.num_speculative_tokens + 1

        # Create the sequence group.
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority,
            draft_size=draft_size)

        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with PoolingParams.

        The pooling params are cloned first so the group does not share the
        caller's instance.
        """
        # Defensive copy of PoolingParams, which are used by the pooler
        pooling_params = pooling_params.clone()
        # Create the sequence group.
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)
        return seq_group

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        Args:
            request_id: The ID(s) of the request to abort. Any iterable of
                IDs is accepted, including one-shot iterators.

        Details:
            - Forwards the ID(s) to
              {meth}`~vllm.core.scheduler.Scheduler.abort_seq_group`
              on every scheduler in ``self.scheduler``.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        # The same request_id object is handed to every scheduler below. A
        # one-shot iterable (e.g. a generator) would be exhausted by the first
        # scheduler, leaving nothing to abort on the others, so materialize
        # it once up front.
        if not isinstance(request_id, str):
            request_id = tuple(request_id)

        for scheduler in self.scheduler:
            scheduler.abort_seq_group(
                request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)

    def get_vllm_config(self) -> VllmConfig:
        """Return the engine's ``vllm_config`` attribute unchanged."""
        full_config = self.vllm_config
        return full_config

    def get_model_config(self) -> ModelConfig:
        """Return the engine's ``model_config`` attribute unchanged."""
        model_cfg = self.model_config
        return model_cfg

    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's ``parallel_config`` attribute unchanged."""
        parallel_cfg = self.parallel_config
        return parallel_cfg

    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's ``decoding_config`` attribute unchanged."""
        decoding_cfg = self.decoding_config
        return decoding_cfg

    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's ``scheduler_config`` attribute unchanged."""
        scheduler_cfg = self.scheduler_config
        return scheduler_cfg

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Gets the LoRA configuration.

        Returns:
            ``self.lora_config``. This may be None (falsy) when LoRA is not
            enabled, which is why the return type is Optional;
            ``add_request`` treats a falsy value as "LoRA disabled".
        """
        return self.lora_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests.

        Returns:
            The sum of ``get_num_unfinished_seq_groups()`` over every
            scheduler in ``self.scheduler`` (0 if there are none).
        """
        return sum(scheduler.get_num_unfinished_seq_groups()
                   for scheduler in self.scheduler)

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests.

        True as soon as any scheduler in ``self.scheduler`` reports
        ``has_unfinished_seqs()``; False when none do (or there are none).
        """
        return any(scheduler.has_unfinished_seqs()
                   for scheduler in self.scheduler)

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Report whether a single virtual engine still has pending work.

        Args:
            virtual_engine: Index into ``self.scheduler`` selecting which
                scheduler to query.

        Returns:
            That scheduler's ``has_unfinished_seqs()`` result.
        """
        target_scheduler = self.scheduler[virtual_engine]
        return target_scheduler.has_unfinished_seqs()

    def reset_prefix_cache(self, device: Optional["Device"] = None) -> bool:
        """Reset prefix cache for all devices.

        Every scheduler in ``self.scheduler`` is asked to reset, even if an
        earlier one fails.

        Args:
            device: Forwarded unchanged to each scheduler's
                ``reset_prefix_cache``.

        Returns:
            True only if every scheduler reported a successful reset (True
            when there are no schedulers).
        """
        # Reset all schedulers first, then combine. Folding with
        # ``success = success and ...`` short-circuits, so after the first
        # failure the remaining schedulers would silently never be reset.
        results = [
            scheduler.reset_prefix_cache(device)
            for scheduler in self.scheduler
        ]
        return all(results)

    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Store pooling output on a sequence group and finish it.

        Only ``outputs[0].data`` is used: it is assigned to
        ``seq_group.pooled_data``. Every sequence in the group is then marked
        ``SequenceStatus.FINISHED_STOPPED``.
        """
        seq_group.pooled_data = outputs[0].data

        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_STOPPED

    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """Advance ``num_computed_tokens`` for a prompt sequence group when
        multi-step scheduling is enabled.

        Decode (non-prompt) groups are left alone here; their count is
        updated after the tokens are appended to the sequence.

        Args:
            seq_group: The SequenceGroup whose ``num_computed_tokens`` may be
                advanced.
            seq_group_meta: Metadata of ``seq_group``; provides ``is_prompt``
                and ``token_chunk_size``.
            is_first_step_output: Whether the appended output token comes
                from the first step of the multi-step run. None means the
                outputs of all steps were submitted in a single burst.
        """
        assert self.scheduler_config.is_multi_step

        # Decodes are handled once their token has been appended.
        if not seq_group_meta.is_prompt:
            return

        if not self.scheduler_config.chunked_prefill_enabled:
            # Plain multi-step: prompts are single-stepped, so always update.
            assert seq_group.state.num_steps == 1
        elif is_first_step_output is not None and not is_first_step_output:
            # Multi-step + chunked prefill: scheduled prompts are fully
            # processed in the first step, so later-step outputs are skipped.
            return

        seq_group.update_num_computed_tokens(seq_group_meta.token_chunk_size)

    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply queued model output to the scheduled sequence groups.

        Takes the oldest entry of ``ctx.output_queue`` (it only peeks at the
        head when ``request_id`` is given) and feeds each sequence group's
        output through the output processor. The resulting request outputs
        are collected into ``ctx.request_outputs``. When
        ``self.process_request_outputs_callback`` is set, they are handed to
        it and ``ctx.request_outputs`` is cleared.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, only this request is processed. Its index
                is then appended to the queue entry's ``skip`` list so it is
                not processed a second time.

        Returns:
            None. Results are delivered through ``ctx.request_outputs`` and/or
            the request-output callback.
        """

        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
            if self.scheduler_config.is_multi_step:
                outputs_by_sequence_group = create_output_by_sequence_group(
                    outputs, len(seq_group_metadata_list))
            elif self.speculative_config:
                # Decodes are multi-steps while prefills are not, outputting at
                # most 1 token. Separate them so that we can trigger chunk
                # processing without having to pad or copy over prompts K times
                # to match decodes structure (costly with prompt_logprobs).
                num_prefills = sum(sg.is_prompt
                                   for sg in seq_group_metadata_list)
                prefills, decodes = outputs[:num_prefills], outputs[
                    num_prefills:]
                outputs_by_sequence_group = create_output_by_sequence_group(
                    decodes,
                    num_seq_groups=len(seq_group_metadata_list) - num_prefills)
                outputs_by_sequence_group = [p.outputs for p in prefills
                                             ] + outputs_by_sequence_group
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
        else:
            outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
                        seq_group_meta.token_chunk_size or 0)

            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            if self.model_config.runner_type == "pooling":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Create the outputs. Indices that were skipped, already finished, or
        # emitted above are excluded; a set keeps each membership test O(1)
        # instead of scanning three lists per index.
        already_handled = set(skip).union(finished_before, finished_now)
        for i in indices:
            if i in already_handled:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        for seq_group in scheduler_outputs.ignored_seq_groups:
            # Unfinished DELTA-mode requests produce no output for ignored
            # groups.
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _advance_to_next_step(
            self, output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences.

        This is normally done inside the output processor, but it is required
        here if the worker is to perform an async forward pass for the next
        step. Finished sequence groups are skipped. Groups that sample must
        have exactly one sample and exactly one sequence; this is asserted.
        """
        is_multi_step = self.scheduler_config.is_multi_step

        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

            if is_multi_step:
                # Updates happen only if the sequence is prefill
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, seq_group_metadata,
                    seq_group.state.num_steps == 1)
            else:
                token_chunk_size = (seq_group_metadata.token_chunk_size
                                    if seq_group_metadata.token_chunk_size
                                    is not None else 0)
                seq_group.update_num_computed_tokens(token_chunk_size)

            if not seq_group_metadata.do_sample:
                continue

            assert len(sequence_group_outputs.samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1)")
            sample = sequence_group_outputs.samples[0]

            assert len(seq_group.seqs) == 1
            seq = seq_group.seqs[0]

            # In multi-step mode, whether this append completes a prefill must
            # be measured BEFORE the token is appended.
            is_prefill_append = (is_multi_step and
                                 seq.data.get_num_uncomputed_tokens() == 0)

            seq.append_token_id(sample.output_token, sample.logprobs,
                                sample.output_embed)

            if is_multi_step and not is_prefill_append:
                seq_group.update_num_computed_tokens(1)

    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
Antoni Baum's avatar
Antoni Baum committed
1257
1258
        """Performs one decoding iteration and returns newly generated results.

1259
1260
1261
        :::{figure} https://i.imgur.com/sv2HssD.png
        :alt: Overview of the step function
        :align: center
1262

1263
1264
        Overview of the step function.
        :::
1265
1266

        Details:
1267
1268
        - Step 1: Schedules the sequences to be executed in the next
            iteration and the token blocks to be swapped in/out/copy.
1269

1270
1271
1272
1273
            - Depending on the scheduling policy,
                sequences may be `preempted/reordered`.
            - A Sequence Group (SG) refer to a group of sequences
                that are generated from the same prompt.
1274

1275
1276
        - Step 2: Calls the distributed executor to execute the model.
        - Step 3: Processes the model output. This mainly includes:
1277

1278
1279
1280
1281
            - Decodes the relevant outputs.
            - Updates the scheduled sequence groups with model outputs
                based on its `sampling parameters` (`use_beam_search` or not).
            - Frees the finished sequence groups.
1282

1283
        - Finally, it creates and returns the newly generated results.
1284
1285

        Example:
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
        ```
        # Please see the example/ folder for more detailed examples.

        # initialize engine and request arguments
        engine = LLMEngine.from_engine_args(engine_args)
        example_inputs = [(0, "What is LLM?",
        SamplingParams(temperature=0.0))]
    
        # Start the engine with an event loop
        while True:
            if example_inputs:
                req_id, prompt, sampling_params = example_inputs.pop(0)
                engine.add_request(str(req_id),prompt,sampling_params)

            # continue the request processing
            request_outputs = engine.step()
            for request_output in request_outputs:
                if request_output.finished:
                    # return or show the request output

            if not (engine.has_unfinished_requests() or example_inputs):
                break
        ```
Antoni Baum's avatar
Antoni Baum committed
1309
        """
1310
1311
1312
1313
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")
1314

1315
        # For llm_engine, there is no pipeline parallel support, so the engine
1316
        # used is always 0.
1317
1318
        virtual_engine = 0

1319
1320
        # These are cached outputs from previous iterations. None if on first
        # iteration
1321
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
1322
1323
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
1324
        allow_async_output_proc = cached_outputs.allow_async_output_proc
1325

1326
1327
        ctx = self.scheduler_contexts[virtual_engine]

1328
1329
1330
        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

1331
1332
1333
        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
1334
1335
1336
1337
1338
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
1339
            # Schedule iteration
1340
            (seq_group_metadata_list, scheduler_outputs,
1341
1342
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
1343

1344
1345
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
1346

1347
1348
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
1349
1350
1351
1352
1353
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]
1354

1355
1356
            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
1357
                self._process_model_outputs(ctx=ctx)
1358

1359
1360
1361
1362
1363
            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
1364
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
1365
                    allow_async_output_proc)
1366
1367
        else:
            finished_requests_ids = list()
1368
1369
1370

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
Antoni Baum's avatar
Antoni Baum committed
1371

1372
        if not scheduler_outputs.is_empty():
1373
1374
1375
1376
1377
1378

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
1379
                self._get_last_sampled_token_ids(virtual_engine)
1380

1381
            execute_model_req = ExecuteModelRequest(
1382
1383
1384
1385
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
1386
1387
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
1388
1389
1390
1391
1392
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

1393
            if allow_async_output_proc:
1394
1395
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]
1396

1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise
1414

1415
            # We need to do this here so that last step's sampled_token_ids can
1416
1417
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
1418
                self._update_cached_scheduler_output(virtual_engine, outputs)
1419
        else:
1420
1421
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
1422
1423
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
1424
            # No outputs in this case
1425
            outputs = []
Antoni Baum's avatar
Antoni Baum committed
1426

1427
1428
1429
1430
1431
1432
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
1433
            # clear the cache if we have finished all the steps.
1434
1435
1436
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()

1437
1438
1439
1440
1441
1442
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

1443
            # Add results to the output_queue
1444
1445
1446
1447
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
1448
1449
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)
1450
1451
1452

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
1453
                    "Async postprocessor expects only a single output set")
1454

1455
                self._advance_to_next_step(
1456
                    outputs[0], seq_group_metadata_list,
1457
                    scheduler_outputs.scheduled_seq_groups)
1458

1459
            # Check if need to run the usual non-async path
1460
            if not allow_async_output_proc:
1461
                self._process_model_outputs(ctx=ctx)
1462

1463
                # Log stats.
1464
                self.do_log_stats(scheduler_outputs, outputs)
1465

1466
1467
1468
                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
1469
            # Multi-step case
1470
            return ctx.request_outputs
1471

1472
        if not self.has_unfinished_requests():
1473
1474
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
1475
                self._process_model_outputs(ctx=ctx)
1476
            assert len(ctx.output_queue) == 0
1477

1478
1479
1480
1481
1482
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
1483
            logger.debug("Stopping remote worker execution loop.")
1484
1485
            self.model_executor.stop_remote_worker_execution_loop()

1486
        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1487

1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Aborts a single request, and caches the scheduler outputs minus that
        request. This allows the next step to continue processing the remaining
        requests without having to re-run the scheduler.

        Used when model execution raised ``InputProcessingError`` for one
        request. The request is aborted, and its entries are removed *in place*
        from ``seq_group_metadata_list`` and
        ``scheduler_outputs.scheduled_seq_groups``.

        If other sequence groups remain, the trimmed schedule is cached for
        ``virtual_engine`` and ``self._skip_scheduling_next_step`` is set. If
        nothing remains, the flag is cleared so the next step schedules afresh.

        Args:
            request_id: ID of the request whose inputs could not be processed.
            virtual_engine: Index of the virtual engine whose schedule failed.
            seq_group_metadata_list: Metadata of the failed batch (mutated).
            scheduler_outputs: Scheduler outputs of the failed batch (mutated).
            allow_async_output_proc: Cached alongside the schedule.
        """
        # Abort the request and remove its sequence group from the current
        # schedule
        self.abort_request(request_id)
        for i, metadata in enumerate(seq_group_metadata_list):
            if metadata.request_id == request_id:
                del seq_group_metadata_list[i]
                break
        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
            if group.seq_group.request_id == request_id:
                del scheduler_outputs.scheduled_seq_groups[i]
                break

        if not seq_group_metadata_list:
            # Nothing left to rerun. The flag may still be True if the batch
            # that just failed was itself a cached rerun. Left set, the next
            # step would reuse the now-empty cached schedule instead of calling
            # the scheduler; the flag is otherwise only reset after a
            # successful model execution.
            self._skip_scheduling_next_step = False
            return

        # Other sequence groups remain: cache them and flag the engine to
        # reuse the schedule.
        self._skip_scheduling_next_step = True
        # Reuse multi-step caching logic
        self._cache_scheduler_outputs_for_multi_step(
            virtual_engine=virtual_engine,
            scheduler_outputs=scheduler_outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            allow_async_output_proc=allow_async_output_proc)

1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return whether the current multi-step batch still has steps to run.

        Always False when multi-step scheduling is disabled or when
        ``seq_group_metadata_list`` is None or empty. Otherwise every sequence
        group must report the same ``state.remaining_steps``, and the result is
        whether that shared count is greater than zero.

        Raises:
            AssertionError: If the sequence groups disagree on the number of
                remaining steps.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        # A generator (rather than a list) lets any() stop at the first
        # mismatch.
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store a schedule in ``self.cached_scheduler_outputs[virtual_engine]``.

        The metadata list, scheduler outputs and async-output flag are saved on
        the cached entry. Its ``last_output`` is reset to None, discarding any
        output cached for a previous schedule.
        """
        cached = self.cached_scheduler_outputs[virtual_engine]
        cached.last_output = None
        cached.allow_async_output_proc = allow_async_output_proc
        cached.scheduler_outputs = scheduler_outputs
        cached.seq_group_metadata_list = seq_group_metadata_list
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Cache the final sampler output of a step for pipeline parallelism.

        Acts only when ``pipeline_parallel_size > 1`` and ``output`` is
        non-empty with a non-None first entry. In that case the last entry is
        stored as ``last_output`` for ``virtual_engine``. The asserts require
        that output to carry CPU sampled token ids and no device-side token
        ids or probabilities.
        """
        if self.parallel_config.pipeline_parallel_size <= 1:
            return
        if len(output) == 0 or output[0] is None:
            return

        final_output = output[-1]
        assert final_output is not None
        assert final_output.sampled_token_ids_cpu is not None
        assert final_output.sampled_token_ids is None
        assert final_output.sampled_token_probs is None
        cache_entry = self.cached_scheduler_outputs[virtual_engine]
        cache_entry.last_output = final_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        """Return the cached CPU sampled token ids, or None.

        Ids are returned only when both multi-step scheduling and pipeline
        parallelism (size > 1) are enabled, and a ``last_output`` with
        ``sampled_token_ids_cpu`` is cached for ``virtual_engine``.
        """
        last = self.cached_scheduler_outputs[virtual_engine].last_output
        if not self.scheduler_config.is_multi_step:
            return None
        if self.parallel_config.pipeline_parallel_size <= 1:
            return None
        if last is None:
            return None
        # sampled_token_ids_cpu may itself be None, which is returned as-is.
        return last.sampled_token_ids_cpu

1576
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` in ``self.stat_loggers`` under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled.
            KeyError: If a logger with this name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry.update({logger_name: logger})

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled.
            KeyError: If no logger with this name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        registry = self.stat_loggers
        if logger_name not in registry:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        registry.pop(logger_name)

1594
1595
1596
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Build a stats snapshot and hand it to every registered logger.

        Does nothing when ``self.log_stats`` is falsy. All arguments are
        forwarded to ``_get_stats``. Called with no arguments it reports only
        system-level state, which allows a forced log when no requests are
        active.
        """
        if not self.log_stats:
            return
        stats = self._get_stats(scheduler_outputs, model_output,
                                finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(stats)
1605

1606
1607
1608
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional, indices into
                ``scheduler_outputs.scheduled_seq_groups`` of groups that were
                finished before. These groups are ignored, to avoid double
                counting with async output processing.
            skip: Optional, indices into
                ``scheduler_outputs.scheduled_seq_groups`` of groups that were
                preempted. These groups are ignored.

        Returns:
            A ``Stats`` snapshot containing:
            - system state: queue sizes, KV cache usage, prefix cache hit
              rates, LoRA adapters;
            - per-iteration token counts and latencies;
            - per-request metrics for groups that finished in this batch.
        """
        now = time.time()

        # Both index lists are consulted once per scheduled group; sets make
        # those membership tests O(1) instead of a list scan each time.
        finished_before_idxs = set(finished_before or ())
        skip_idxs = set(skip or ())

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        time_in_queue_requests: List[float] = []
        model_forward_time_requests: List[float] = []
        model_execute_time_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests (adapter names in order of first appearance)
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_before_idxs:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skip_idxs:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_token_latency()
                    time_per_output_tokens_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    metrics = seq_group.metrics
                    # Latency timings
                    time_e2e_requests.append(now - metrics.arrival_time)
                    if (metrics.first_scheduled_time is not None
                            and metrics.first_token_time is not None):
                        time_queue_requests.append(
                            metrics.first_scheduled_time -
                            metrics.arrival_time)
                        time_prefill_requests.append(
                            metrics.first_token_time -
                            metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - metrics.first_token_time)
                        time_inference_requests.append(
                            now - metrics.first_scheduled_time)
                    if metrics.time_in_queue is not None:
                        time_in_queue_requests.append(metrics.time_in_queue)
                    if metrics.model_forward_time is not None:
                        model_forward_time_requests.append(
                            metrics.model_forward_time)
                    if metrics.model_execute_time is not None:
                        # Scaled by 1000, presumably s -> ms to match
                        # model_forward_time; confirm units with producer.
                        model_execute_time_requests.append(
                            metrics.model_execute_time * 1000)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and isinstance(model_output[0], SamplerOutput) and (
                model_output[0].spec_decode_worker_metrics is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            time_in_queue_requests=time_in_queue_requests,
            model_forward_time_requests=model_forward_time_requests,
            model_execute_time_requests=model_execute_time_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            # max_lora_stat is already a str.
            max_lora=max_lora_stat,
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
1875

1876
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to ``model_executor.add_lora``.

        Returns:
            The boolean result reported by the model executor.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1878
1879

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to ``model_executor.remove_lora``.

        Returns:
            The boolean result reported by the model executor.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1881

1882
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by ``model_executor.list_loras``."""
        executor = self.model_executor
        return executor.list_loras()
1884

1885
1886
1887
    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to ``model_executor.pin_lora``.

        Returns:
            The boolean result reported by the model executor.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward the request to ``model_executor.add_prompt_adapter``.

        Returns:
            The boolean result reported by the model executor.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward the id to ``model_executor.remove_prompt_adapter``.

        Returns:
            The boolean result reported by the model executor.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the ids reported by ``model_executor.list_prompt_adapters``."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1898
1899
1900
1901
1902
1903
    def start_profile(self) -> None:
        """Ask the model executor to start profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Ask the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

1904
1905
1906
1907
1908
    def sleep(self, level: int = 1) -> None:
        """Call ``model_executor.sleep(level=level)``.

        Raises:
            AssertionError: If ``enable_sleep_mode`` is not set in the model
                config.
        """
        sleep_enabled = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_enabled, "Sleep mode is not enabled in the model config"
        self.model_executor.sleep(level=level)

1909
    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Call ``model_executor.wake_up`` with ``tags`` passed through as-is.

        Raises:
            AssertionError: If ``enable_sleep_mode`` is not set in the model
                config.
        """
        model_config = self.vllm_config.model_config
        assert model_config.enable_sleep_mode, (
            "Sleep mode is not enabled in the model config")
        self.model_executor.wake_up(tags)
1913

1914
1915
1916
    def is_sleeping(self) -> bool:
        """Return the model executor's ``is_sleeping`` attribute."""
        executor = self.model_executor
        return executor.is_sleeping

1917
    def check_health(self) -> None:
        """Run ``model_executor.check_health``; its exceptions propagate."""
        executor = self.model_executor
        executor.check_health()
1919
1920
1921
1922

    def is_tracing_enabled(self) -> bool:
        """Return True when ``self.tracer`` is set (not None)."""
        tracer = self.tracer
        return tracer is not None

1923
1924
1925
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for each finished scheduled sequence group.

        Does nothing when ``self.tracer`` is None. Indices into
        ``scheduler_outputs.scheduled_seq_groups`` listed in
        ``finished_before`` are skipped, avoiding double tracing when async
        output processing is used.
        """
        if self.tracer is not None:
            already_traced = finished_before or []
            for position, scheduled in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                if position in already_traced:
                    continue
                group = scheduled.seq_group
                if group.is_finished():
                    self.create_trace_span(group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit an ``llm_request`` server span for a finished sequence group.

        No-op when no tracer is configured or the group has no sampling
        params. The span:
        - starts at the request's arrival time;
        - uses the trace context extracted from the group's trace headers;
        - records request parameters, token usage and latency attributes.

        Time-to-first-token and end-to-end latency are only recorded when the
        timestamps they are derived from are not None, so a missing timestamp
        cannot raise ``TypeError`` while the span is open.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            sampling_params = seq_group.sampling_params
            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   sampling_params.n)
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))
            seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                                   metrics.time_in_queue)
            # first_token_time can be None (the stats path checks for this as
            # well); only report latencies whose timestamps are present.
            if metrics.first_token_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                    metrics.first_token_time - metrics.arrival_time)
            if metrics.finished_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_E2E,
                    metrics.finished_time - metrics.arrival_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # Divided by 1000, presumably ms -> s, while
                # model_execute_time below is used as-is; confirm units.
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)
1993

1994
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Validate processed inputs before a request is created from them.

        The inputs are split into their encoder and decoder parts. The encoder
        part is validated first (only when ``split_enc_dec_inputs`` returns
        one, i.e. it is not ``None``), then the decoder part is always
        validated. Each part goes through ``_validate_model_input``, which
        raises ``ValueError`` on invalid input.
        """
        encoder_part, decoder_part = split_enc_dec_inputs(inputs)

        # split_enc_dec_inputs yields None for the encoder part when there
        # is nothing to validate on that side.
        if encoder_part is not None:
            self._validate_model_input(
                encoder_part,
                lora_request,
                prompt_type="encoder",
            )

        self._validate_model_input(
            decoder_part,
            lora_request,
            prompt_type="decoder",
        )

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Check a single (encoder or decoder) prompt.

        Raises ``ValueError`` when:

        * the prompt has no token ids. This is allowed for an encoder prompt
          of a multimodal model (Mllama may have empty encoder inputs for
          text-only data) and for prompts whose ``type`` is ``"embeds"``;
        * a tokenizer is available and the largest token id exceeds
          ``tokenizer.max_token_id``;
        * the prompt is longer than ``model_config.max_model_len``. This check
          is skipped for encoder prompts of multimodal models whose
          encoder-decoder processor sets ``pad_dummy_encoder_prompt``
          (e.g. Whisper).

        Args:
            prompt_inputs: Processed inputs for one side of the model.
            lora_request: Optional LoRA request; selects the tokenizer.
            prompt_type: Which side the prompt belongs to. It is used in error
                messages and enables the multimodal-encoder exemptions.
        """
        model_config = self.model_config
        if self.tokenizer is None:
            tokenizer = None
        else:
            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)

        token_ids = prompt_inputs.get("prompt_token_ids", [])

        if not token_ids:
            # Mllama may have empty encoder inputs for text-only data;
            # "embeds" prompts are exempt as well.
            empty_ok = ((prompt_type == "encoder"
                         and model_config.is_multimodal_model)
                        or prompt_inputs["type"] == "embeds")
            if not empty_ok:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        if tokenizer is not None:
            highest_id = max(token_ids, default=0)
            if highest_id > tokenizer.max_token_id:
                raise ValueError(
                    f"Token id {highest_id} is out of vocabulary")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
        max_prompt_len = model_config.max_model_len
        if len(token_ids) <= max_prompt_len:
            return

        if prompt_type == "encoder" and model_config.is_multimodal_model:
            mm_registry = self.input_preprocessor.mm_registry
            mm_processor = mm_registry.create_processor(
                model_config,
                tokenizer=tokenizer or object(),  # Dummy if no tokenizer
            )
            assert isinstance(mm_processor, EncDecMultiModalProcessor)

            if mm_processor.pad_dummy_encoder_prompt:
                return  # Skip encoder length check for Whisper

        suggestion = (
            "Make sure that `max_model_len` is no smaller than the "
            "number of text tokens plus multimodal tokens. For image "
            "inputs, the number of image tokens depends on the number "
            "of images, and possibly their aspect ratios as well."
            if model_config.is_multimodal_model else
            "Make sure that `max_model_len` is no smaller than the "
            "number of text tokens.")

        raise ValueError(
            f"The {prompt_type} prompt (length {len(token_ids)}) is "
            f"longer than the maximum model length of {max_prompt_len}. "
            f"{suggestion}")

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Construct logits processors from request-level sampling fields.

        Builds processors for the ``guided_decoding``, ``logit_bias``,
        ``allowed_token_ids`` and ``bad_words`` fields of ``sampling_params``.
        They are appended after any processors already present in
        ``sampling_params.logits_processors``. ``guided_decoding``,
        ``logit_bias`` and ``allowed_token_ids`` are then unset so they are
        not passed down to the model; ``bad_words`` is left as it is.

        The caller's ``sampling_params`` object and its ``logits_processors``
        list are not mutated. Whenever a field has to change, the change is
        made on a shallow copy, and the processor list is rebuilt instead of
        extended in place. A shallow copy still shares the original list, and
        the same ``SamplingParams`` may be shared by several requests.
        (Note: ``guided_decoding.backend`` is still filled in on the
        guided-decoding params object, which the shallow copy shares.)

        Args:
            sampling_params: Sampling parameters of the incoming request.
            lora_request: Optional LoRA request; selects the tokenizer used
                to build the processors.

        Returns:
            ``sampling_params`` itself when nothing had to change, otherwise
            a modified shallow copy of it.
        """

        logits_processors = []
        # Kept to implement copy-on-write: we copy at most once, and only
        # when we are about to modify the object.
        original_params = sampling_params

        if sampling_params.guided_decoding is not None:
            # Defensively copy sampling params since guided decoding logits
            # processors can have different state for each request
            sampling_params = copy.copy(sampling_params)
            guided_decoding = sampling_params.guided_decoding

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.backend

            if self.decoding_config.reasoning_backend:
                logger.debug("Building with reasoning backend %s",
                             self.decoding_config.reasoning_backend)

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config,
                reasoning_backend=self.decoding_config.reasoning_backend,
            )
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
            tokenizer = self.get_tokenizer(lora_request=lora_request)

            processors = get_openai_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model. Copy first
            # so the caller's object keeps its own fields.
            if sampling_params is original_params:
                sampling_params = copy.copy(sampling_params)
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if len(sampling_params.bad_words) > 0:
            tokenizer = self.get_tokenizer(lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

        if logits_processors:
            if sampling_params is original_params:
                sampling_params = copy.copy(sampling_params)
            # Build a fresh list rather than extending in place: after a
            # shallow copy the existing list is still shared with the
            # caller's object, and extending it would leak this request's
            # processors into later requests that reuse the same params.
            existing = sampling_params.logits_processors or []
            sampling_params.logits_processors = [*existing, *logits_processors]

        return sampling_params

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC to the model executor.

        This is a direct pass-through to
        ``self.model_executor.collective_rpc``. ``method``, ``timeout``,
        ``args`` and ``kwargs`` are forwarded positionally and unchanged, and
        the executor's result list is returned as-is.
        """
        executor = self.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)


# When VLLM_USE_V1 is explicitly set (``envs.is_set``) and truthy, rebind the
# module-level name ``LLMEngine`` to the V1 engine class. Code that imports
# ``LLMEngine`` from this module then receives the V1 implementation instead
# of the class defined above.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    # Imported lazily so the V1 module is only loaded when it is selected.
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    LLMEngine = V1LLMEngine  # type: ignore