llm_engine.py 92.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import copy
Antoni Baum's avatar
Antoni Baum committed
4
import time
5
from collections import Counter as collectionsCounter
6
from collections import deque
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
11
                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
12
from typing import Sequence as GenericSequence
13
from typing import Set, Type, Union, cast, overload
14

15
import torch
16
from typing_extensions import TypeVar, deprecated
17

18
import vllm.envs as envs
19
20
21
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
22
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
28
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
29
30
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
31
from vllm.executor.executor_base import ExecutorBase
32
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
33
                         PromptType, SingletonInputs)
34
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
35
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.logger import init_logger
37
from vllm.logits_process import get_bad_words_logits_processors
38
from vllm.lora.request import LoRARequest
39
40
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
41
from vllm.model_executor.layers.sampler import SamplerOutput
42
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
43
from vllm.multimodal.processing import EncDecMultiModalProcessor
44
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
45
46
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
47
from vllm.prompt_adapter.request import PromptAdapterRequest
48
from vllm.sampling_params import RequestOutputKind, SamplingParams
49
50
51
52
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceStatus)
53
54
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
55
from vllm.transformers_utils.detokenizer import Detokenizer
56
from vllm.transformers_utils.tokenizer import AnyTokenizer
57
from vllm.transformers_utils.tokenizer_group import (
58
    TokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
59
60
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
61
62
from vllm.utils import (Counter, Device, deprecate_kwargs,
                        resolve_obj_by_qualname, weak_bind)
63
from vllm.version import __version__ as VLLM_VERSION
64
from vllm.worker.model_runner_base import InputProcessingError
65
66

# Passed as ``local_interval`` to the default stat loggers built in
# LLMEngine.__init__.
_LOCAL_LOGGING_INTERVAL_SEC = 5

logger = init_logger(__name__)

# Request-output type checked by LLMEngine.validate_output(s).
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
# Generic result type whose default is Any (PEP 696 ``default=`` via
# typing_extensions); its users are outside this excerpt.
_R = TypeVar("_R", default=Any)
71
72


73
74
75
76
77
@dataclass
class SchedulerOutputState:
    """Scheduler outputs cached for a single virtual engine (multi-step).

    Every field has an "empty" default, so an instance can be created with
    no arguments; the engine creates one per pipeline-parallel rank.
    """
    seq_group_metadata_list: Union[List[SequenceGroupMetadata], None] = None
    scheduler_outputs: Union[SchedulerOutputs, None] = None
    allow_async_output_proc: bool = False
    last_output: Union[SamplerOutput, None] = None
80
81


82
83
84
85
86
87
class OutputData(NamedTuple):
    """One batch of model outputs queued together with the scheduling
    context (metadata list and scheduler outputs) that produced it."""
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool

    # Whether this output belongs to the first step of a multi-step run.
    # Always True when multi-step is disabled. The value is meaningless
    # when ``outputs`` holds outputs from more than one step.
    is_first_step_output: Union[bool, None]

    skip: List[int]


97
class SchedulerContext:
    """Mutable state that one virtual engine uses while processing outputs.

    ``LLMEngine`` keeps one instance per pipeline-parallel virtual engine.
    Entries added through :meth:`append_output` accumulate in
    ``output_queue`` as :class:`OutputData` tuples.
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []

        # Scheduler results for the step being handled; start out unset.
        self.seq_group_metadata_list: Union[List[SequenceGroupMetadata],
                                            None] = None
        self.scheduler_outputs: Union[SchedulerOutputs, None] = None

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Queue one step's outputs along with their scheduling context.

        The new :class:`OutputData` entry is appended to the right end of
        ``output_queue`` and starts with an empty ``skip`` list.
        """
        entry = OutputData(
            outputs=outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            scheduler_outputs=scheduler_outputs,
            is_async=is_async,
            is_last_step=is_last_step,
            is_first_step_output=is_first_step_output,
            skip=[],
        )
        self.output_queue.append(entry)
122
123


124
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
125
    """An LLM engine that receives requests and generates texts.
126

Woosuk Kwon's avatar
Woosuk Kwon committed
127
    This is the main class for the vLLM engine. It receives requests
128
129
130
131
132
133
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

134
135
    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
136

137
    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
138
    :ref:`engine-args`)
139
140
141
142
143
144
145

    Args:
        vllm_config: The engine configuration. It bundles the configurations
            related to the LLM model, the KV cache memory management,
            distributed (parallel) execution, the request scheduler, the
            device, model loading, decoding and observability, plus the
            optional configurations for serving multi-LoRA, speculative
            decoding and serving prompt adapters.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
        stat_loggers (Optional): Stat loggers keyed by name. When omitted and
            ``log_stats`` is True, default logging and Prometheus stat
            loggers are created.
        input_registry: The registry used to create the input processor.
        mm_registry: The multimodal registry passed to the input
            preprocessor.
        use_cached_outputs: Stored on the engine as ``use_cached_outputs``
            and logged at start-up.
    """
157

158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Temporarily enable runtime type validation of request outputs.

        While the context is active, ``DO_VALIDATE_OUTPUT`` is True, so
        ``validate_output`` / ``validate_outputs`` perform ``isinstance``
        checks. On exit the flag is restored to the value it had on entry.
        This also happens when the ``with`` body raises, and it lets nested
        uses keep validation on for the outer context.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            # Without try/finally an exception inside the body would leave
            # validation permanently enabled for the whole class.
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` typed as ``output_type``.

        The ``isinstance`` check only runs when ``DO_VALIDATE_OUTPUT`` is set
        (or under static type checking); otherwise this is a plain cast.

        Raises:
            TypeError: validation is active and ``output`` is not an
                instance of ``output_type``.
        """
        should_check = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if should_check and not isinstance(output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return cast(_O, output)
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return ``outputs`` typed as a list of ``output_type``.

        When validation is off (``DO_VALIDATE_OUTPUT`` False at runtime),
        the input sequence itself is returned unchanged, with no copy. When
        it is on, each element is checked and a new list is built.

        Raises:
            TypeError: validation is active and some element is not an
                instance of ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            return cast(List[_O], outputs)

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)
        return checked

207
    tokenizer: Optional[TokenizerGroup]
208

209
210
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build a V0 engine from a resolved :class:`VllmConfig`.

        Args:
            vllm_config: The full engine configuration; its sub-configs are
                copied onto ``self`` (model, cache, lora, parallel, ...).
            executor_class: Model executor class, instantiated with
                ``vllm_config``.
            log_stats: Whether to set up stat loggers.
            usage_context: Entry point reported with usage stats.
            stat_loggers: Pre-built stat loggers keyed by name. When None and
                ``log_stats`` is True, a ``"logging"`` and a ``"prometheus"``
                logger are created.
            input_registry: Registry used to create the input processor.
            mm_registry: Multimodal registry given to the input preprocessor.
            use_cached_outputs: Stored as ``self.use_cached_outputs`` and
                logged.

        Statement order matters here: for non-pooling runners the KV caches
        are initialized (which writes the block counts into
        ``cache_config``) before the schedulers are constructed.

        Raises:
            ValueError: ``envs.VLLM_USE_V1`` is set.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = (vllm_config.decoding_config
                                or DecodingConfig())
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = (vllm_config.observability_config
                                     or ObservabilityConfig())

        # The format string previously ended with a stray ", ".
        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(vllm_config=vllm_config)

        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        # One cached-output state and one scheduler context per virtual
        # engine (pipeline-parallel rank).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            # NOTE(review): weak_bind presumably keeps the callbacks from
            # holding a strong reference to the engine - confirm in
            # vllm.utils.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: for non-pooling runners the cache_config here has been
        # updated with the numbers of GPU and CPU blocks, which are profiled
        # in the distributed executor.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

420
421
422
423
424
425
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The executor reports how many GPU and CPU blocks are available. If
        ``cache_config.num_gpu_blocks_override`` is set, it replaces the
        GPU count (and the override is logged). The final counts are written
        to ``self.cache_config`` and passed to
        ``model_executor.initialize_cache``. The total time is logged.
        """
        # perf_counter is monotonic; time.time() follows the wall clock and
        # can jump (NTP/DST adjustments), yielding bogus durations.
        start = time.perf_counter()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
        if num_gpu_blocks_override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
445

446
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Resolve the executor class selected by ``engine_config``.

        ``parallel_config.distributed_executor_backend`` is either an
        ``ExecutorBase`` subclass, used as-is, or one of the strings
        ``"ray"``, ``"mp"``, ``"uni"`` or ``"external_launcher"``. Each
        backend's module is imported only inside its own branch.

        Raises:
            TypeError: a class was given that does not subclass
                ``ExecutorBase``.
            ValueError: the backend string is unrecognized, or ``"mp"`` is
                combined with ``VLLM_USE_RAY_SPMD_WORKER=1``.
        """
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            # Explicit raise instead of ``assert`` so this configuration
            # check is not stripped when Python runs with -O.
            if envs.VLLM_USE_RAY_SPMD_WORKER:
                raise ValueError(
                    "multiprocessing distributed executor backend does not "
                    "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingDistributedExecutor
        elif distributed_executor_backend == "uni":
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # executor with external launcher
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("unrecognized distributed_executor_backend: "
                             f"{distributed_executor_backend}")
        return executor_class

484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Construct an engine directly from a resolved ``VllmConfig``.

        The executor class comes from :meth:`_get_executor_cls`. Stats are
        logged unless ``disable_log_stats`` is True.
        """
        executor_cls = cls._get_executor_cls(vllm_config)
        enable_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=enable_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

500
501
502
503
504
505
506
507
508
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The config is built with ``engine_args.create_engine_config``. If
        ``envs.VLLM_USE_V1`` is set, the V1 engine class is used instead of
        ``cls``. Construction is delegated to ``from_vllm_config``.
        """
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            target_cls = V1LLMEngine
        else:
            target_cls = cls

        return target_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
        )
522

523
524
525
526
527
    def __reduce__(self):
        # Refuse pickling outright: an engine must never end up captured by
        # the closure that initializes Ray worker actors.
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

528
529
530
531
532
533
    def __del__(self):
        # On garbage collection, shut the model executor down. __init__ may
        # have raised before ``model_executor`` was assigned, so look it up
        # defensively instead of accessing the attribute directly.
        executor = getattr(self, "model_executor", None)
        if executor:
            executor.shutdown()

534
535
    def get_tokenizer_group(self) -> TokenizerGroup:
        """Return the engine's tokenizer group.

        Raises:
            ValueError: ``self.tokenizer`` is None (the engine was built
                with ``skip_tokenizer_init``).
        """
        group = self.tokenizer
        if group is not None:
            return group
        raise ValueError("Unable to get tokenizer because "
                         "skip_tokenizer_init is True")
540

541
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group picks for ``lora_request``.

        Raises:
            ValueError: propagated from ``get_tokenizer_group`` when the
                engine has no tokenizer.
        """
        group = self.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
546

547
    def _init_tokenizer(self) -> TokenizerGroup:
        """Build a tokenizer group from the model, scheduler and LoRA configs."""
        config_kwargs = dict(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            lora_config=self.lora_config,
        )
        return init_tokenizer_from_configs(**config_kwargs)
552

553
554
    def _verify_args(self) -> None:
        """Cross-check the engine's configs against one another.

        The model and cache configs are always checked against the parallel
        config. The LoRA and prompt-adapter configs are checked only when
        present (truthy). Presumably each ``verify_with_*`` raises on an
        incompatible combination; confirm in the config classes.
        """
        parallel = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel)
        self.cache_config.verify_with_parallel_config(parallel)

        lora = self.lora_config
        if lora:
            lora.verify_with_model_config(self.model_config)
            lora.verify_with_scheduler_config(self.scheduler_config)

        adapter = self.prompt_adapter_config
        if adapter:
            adapter.verify_with_model_config(self.model_config)
563

564
565
566
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Returns the created sequence group. If ``params`` is a
        ``SamplingParams`` with ``n > 1``, the request is instead handed to
        ``ParallelSampleSequenceGroup.add_request`` (which receives this
        engine), and None is returned.

        Raises:
            ValueError: ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)

        # Build the decoder sequence and, if present, the encoder sequence.
        # Both share the same sequence id.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)
        if encoder_inputs is None:
            encoder_seq = None
        else:
            encoder_seq = Sequence(seq_id, encoder_inputs, block_size,
                                   eos_token_id, lora_request,
                                   prompt_adapter_request)

        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Route the group to the scheduler with the fewest unfinished
        # sequence groups; ties go to the earliest scheduler in the list.
        least_loaded = min(
            self.scheduler,
            key=lambda sched: sched.get_num_unfinished_seq_groups())
        least_loaded.add_seq_group(seq_group)

        return seq_group

643
644
    def stop_remote_worker_execution_loop(self) -> None:
        """Forward the stop request to the model executor."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
645

646
    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Typing-only overload for the preferred ``prompt`` signature."""

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only overload for the deprecated keyword-only ``inputs``
        # spelling; the runtime implementation is the final ``add_request``
        # definition. The deprecation message previously lacked the closing
        # quote after 'prompt'.
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: The LoRA request to add.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter request to add.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            inputs: Deprecated alias for ``prompt``; when given it takes
                precedence over ``prompt``.

        Raises:
            ValueError: If ``prompt`` or ``params`` is missing; if a LoRA
                request is given while LoRA is not enabled; if a non-zero
                priority is given without priority scheduling; if guided
                decoding or logits processors are combined with multi-step
                scheduling; or if a token prompt contains an out-of-vocabulary
                token id.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create `n` number of :class:`~vllm.Sequence` objects.
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if inputs is not None:
            prompt = inputs
        # Explicit check instead of `assert`, which is stripped under -O and
        # would let a missing argument fail obscurely further down.
        if prompt is None or params is None:
            raise ValueError("Both 'prompt' and 'params' must be provided.")

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if isinstance(params, SamplingParams) \
            and (params.guided_decoding or params.logits_processors) \
            and self.scheduler_config.num_scheduler_steps > 1:
            raise ValueError(
                "Guided decoding and logits processors are not supported "
                "in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        # Token-id validation needs a tokenizer; skip it when none is loaded.
        if self.tokenizer is not None:
            self._validate_token_prompt(
                prompt,
                tokenizer=self.get_tokenizer(lora_request=lora_request))

        preprocessed_inputs = self.input_preprocessor.preprocess(
            prompt,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    def _validate_token_prompt(self, prompt: PromptType,
                               tokenizer: AnyTokenizer):
        """Reject token prompts whose ids fall outside the tokenizer's range.

        Only prompts recognised by ``is_token_prompt`` are checked; other
        prompt forms pass through untouched, as do empty token lists.

        Args:
            prompt: The incoming prompt.
            tokenizer: Tokenizer whose ``max_token_id`` bounds valid ids.

        Raises:
            ValueError: If any token id is greater than
                ``tokenizer.max_token_id`` or is negative.
        """
        # Guard against out-of-vocab tokens.
        # For some tokenizers, tokenizer.decode will happily return empty text
        # for token ids that are out of vocab, and we don't detect token ids
        # that are greater than the max token id before running the model.
        # However, these token ids will later crash a cuda kernel at runtime
        # with an index out of bounds error. This will crash the entire engine.
        # This needs to happen before multimodal input pre-processing, which
        # may add dummy <image> tokens that aren't part of the tokenizer's
        # vocabulary.
        if not is_token_prompt(prompt):
            return

        prompt_ids = prompt["prompt_token_ids"]
        if len(prompt_ids) == 0:
            # Empty prompt check is handled later
            return

        max_input_id = max(prompt_ids)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        # A negative id is just as far outside the vocabulary as one above
        # max_token_id and would hit the same out-of-bounds failure.
        min_input_id = min(prompt_ids)
        if min_input_id < 0:
            raise ValueError(f"Token id {min_input_id} is out of vocabulary")

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a SequenceGroup for a generation (sampling) request.

        Checks the requested logprob counts against the model config's
        ``max_logprobs``. It then attaches logits processors, clones the
        sampling params and applies generation-config defaults. Finally it
        wraps ``seq`` in a new SequenceGroup.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds
                ``max_logprobs``.
        """
        limit = self.get_model_config().max_logprobs
        requested_counts = (sampling_params.logprobs,
                            sampling_params.prompt_logprobs)
        if any(count and count > limit for count in requested_counts):
            raise ValueError(f"Cannot request more than "
                             f"{limit} logprobs.")

        # Attach logits processors, then take a defensive copy of the params
        # used by the sampler. The copy does not deep-copy LogitsProcessor
        # objects.
        sampling_params = self._build_logits_processors(
            sampling_params, lora_request).clone()

        sampling_params.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        # draft_size is 1 without a speculative config, otherwise
        # num_speculative_tokens + 1.
        spec_config = self.vllm_config.speculative_config
        draft_size = (1 if spec_config is None else
                      spec_config.num_speculative_tokens + 1)

        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority,
            draft_size=draft_size)

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a SequenceGroup for a pooling request.

        The group receives a clone of ``pooling_params``, so later changes to
        the caller's object do not leak into it.
        """
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            # Defensive copy of the params consumed by the pooler.
            pooling_params=pooling_params.clone(),
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests by ID on every scheduler.

        Args:
            request_id: A single request ID or an iterable of request IDs.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        # Every scheduler receives the same ID(s) and the shared seq-id map.
        seq_map = self.seq_id_to_seq_group
        for sched in self.scheduler:
            sched.abort_seq_group(request_id, seq_id_to_seq_group=seq_map)

    def get_vllm_config(self) -> VllmConfig:
        """Return the engine's ``vllm_config``."""
        config = self.vllm_config
        return config

    def get_model_config(self) -> ModelConfig:
        """Return the engine's ``model_config``."""
        config = self.model_config
        return config

    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's ``parallel_config``."""
        config = self.parallel_config
        return config

    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's ``decoding_config``."""
        config = self.decoding_config
        return config

    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's ``scheduler_config``."""
        config = self.scheduler_config
        return config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Return the engine's ``lora_config``.

        The value may be unset (e.g. None) when LoRA is not enabled.
        ``add_request`` treats a falsy ``lora_config`` as "LoRA is not
        enabled", so callers must not assume a LoRAConfig is always returned.
        """
        return self.lora_config

    def get_num_unfinished_requests(self) -> int:
        """Return the total count of unfinished seq groups over all schedulers."""
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total

    def has_unfinished_requests(self) -> bool:
        """Return True as soon as any scheduler reports unfinished sequences."""
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Return True if the scheduler at index ``virtual_engine`` has
        unfinished sequences.
        """
        target = self.scheduler[virtual_engine]
        return target.has_unfinished_seqs()

    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the prefix cache on every scheduler.

        Each scheduler's reset is attempted regardless of whether an earlier
        one failed. Combining results with a short-circuiting ``and`` would
        silently skip every scheduler after the first failure.

        Args:
            device: Forwarded unchanged to each scheduler's
                ``reset_prefix_cache``.

        Returns:
            True only if every scheduler reported a successful reset (True
            when there are no schedulers).
        """
        # Run all resets first, then combine, so none are skipped.
        results = [sched.reset_prefix_cache(device) for sched in self.scheduler]
        return all(results)

    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Store pooled data on ``seq_group`` and mark its sequences finished.

        Only ``outputs[0].data`` is used. Every sequence of the group is then
        set to ``SequenceStatus.FINISHED_STOPPED``.
        """
        first = outputs[0]
        seq_group.pooled_data = first.data

        for member in seq_group.get_seqs():
            member.status = SequenceStatus.FINISHED_STOPPED

    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """Update num_computed_tokens for a prompt seq group in multi-step.

        Decode seq groups are left alone here. Their num_computed_tokens is
        updated after the tokens are appended to the sequence.

        Args:
            seq_group: SequenceGroup whose num_computed_tokens may be updated.
            seq_group_meta: Metadata of ``seq_group``. Supplies ``is_prompt``
                and ``token_chunk_size``.
            is_first_step_output: When available, whether the appended output
                token is the output of the first step in multi-step. None
                means outputs from all steps were submitted in a single burst.
        """
        assert self.scheduler_config.is_multi_step

        if not seq_group_meta.is_prompt:
            return

        if self.scheduler_config.chunked_prefill_enabled:
            # Multi-step + chunked prefill: scheduled prompts are fully
            # processed in the first step, so later per-step outputs must not
            # update the count again.
            if is_first_step_output is not None and not is_first_step_output:
                return
        else:
            # Plain multi-step: prompts are effectively single-stepped, so the
            # update always applies.
            assert seq_group.state.num_steps == 1

        seq_group.update_num_computed_tokens(
            seq_group_meta.token_chunk_size)

    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups.

        The oldest entry of ``ctx.output_queue`` is consumed. The per-group
        model outputs are applied, and the resulting request outputs are
        appended to ``ctx.request_outputs``. When
        ``process_request_outputs_callback`` is set, the queued outputs are
        handed to it and ``ctx.request_outputs`` is cleared. Nothing is
        returned.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, only this request is processed. The
                queue entry is then left in place (not popped) and the
                request's index is appended to the entry's ``skip`` list, so
                the later full pass does not process it twice.
        """
        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
            if self.scheduler_config.is_multi_step:
                outputs_by_sequence_group = create_output_by_sequence_group(
                    outputs, len(seq_group_metadata_list))
            elif self.speculative_config:
                # Decodes are multi-steps while prefills are not, outputting at
                # most 1 token. Separate them so that we can trigger chunk
                # processing without having to pad or copy over prompts K times
                # to match decodes structure (costly with prompt_logprobs).
                num_prefills = sum(sg.is_prompt
                                   for sg in seq_group_metadata_list)
                prefills, decodes = outputs[:num_prefills], outputs[
                    num_prefills:]
                outputs_by_sequence_group = create_output_by_sequence_group(
                    decodes,
                    num_seq_groups=len(seq_group_metadata_list) - num_prefills)
                outputs_by_sequence_group = [p.outputs for p in prefills
                                             ] + outputs_by_sequence_group
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
        else:
            outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        # Indices already handled by earlier single-request calls. A set keeps
        # the per-index membership test O(1) instead of scanning the list.
        skipped = set(skip)

        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            if i in skipped:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
                        seq_group_meta.token_chunk_size or 0)

            # Accumulate model forward/execute timings from every
            # SamplerOutput into this group's metrics.
            for o in outputs:
                if (isinstance(o, SamplerOutput)
                        and seq_group.metrics is not None):
                    if seq_group.metrics.model_forward_time is not None:
                        seq_group.metrics.model_forward_time += (
                            o.model_forward_time or 0)
                    else:
                        seq_group.metrics.model_forward_time = (
                            o.model_forward_time)
                    if seq_group.metrics.model_execute_time is not None:
                        seq_group.metrics.model_execute_time += (
                            o.model_execute_time or 0)
                    else:
                        seq_group.metrics.model_execute_time = (
                            o.model_execute_time)

            if self.model_config.runner_type == "pooling":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            self._append_request_output(
                ctx, scheduler_outputs.scheduled_seq_groups[i].seq_group, now)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Create the outputs for every group not already handled above.
        # A set makes the double-processing check O(1) per index.
        already_handled = set(skip).union(finished_before, finished_now)
        for i in indices:
            if i in already_handled:
                continue  # Avoids double processing
            self._append_request_output(
                ctx, scheduler_outputs.scheduled_seq_groups[i].seq_group, now)

        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        for seq_group in scheduler_outputs.ignored_seq_groups:
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _append_request_output(self, ctx: SchedulerContext,
                               seq_group: SequenceGroup, now: float) -> None:
        """Stamp token timings on ``seq_group`` and queue its request output.

        Calls ``maybe_set_first_token_time(now)``. Once the group is no longer
        in prefill it also sets the last-token time to ``now``. It then
        appends the output built by ``RequestOutputFactory.create`` to
        ``ctx.request_outputs``, unless the factory returned a falsy value.
        """
        seq_group.maybe_set_first_token_time(now)
        if not seq_group.is_prefill():
            seq_group.set_last_token_time(now)
        request_output = RequestOutputFactory.create(
            seq_group,
            self.seq_id_to_seq_group,
            use_cache=self.use_cached_outputs)
        if request_output:
            ctx.request_outputs.append(request_output)

    def _advance_to_next_step(
            self, output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Append the sampled tokens from one model run to their sequences.

        The output processor normally does this work. It has to happen here
        when the worker runs the next forward pass asynchronously. Finished
        groups are skipped. Each remaining group first has its computed-token
        count advanced. If it samples, it then receives its single sampled
        token.
        """
        batch = zip(seq_group_metadata_list, output, scheduled_seq_groups)
        for meta, group_output, scheduled in batch:
            group = scheduled.seq_group
            if group.is_finished():
                continue

            if self.scheduler_config.is_multi_step:
                # Updates happen only if the sequence is prefill
                self._update_num_computed_tokens_for_multi_step_prefill(
                    group, meta, group.state.num_steps == 1)
            else:
                chunk = meta.token_chunk_size
                group.update_num_computed_tokens(
                    0 if chunk is None else chunk)

            if not meta.do_sample:
                continue

            assert len(group_output.samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1)")
            sample = group_output.samples[0]

            assert len(group.seqs) == 1
            seq = group.seqs[0]

            if not self.scheduler_config.is_multi_step:
                seq.append_token_id(sample.output_token, sample.logprobs)
                continue

            # Multi-step: check the uncomputed-token count before appending.
            # Only when it was non-zero does the group get one more computed
            # token afterwards.
            prefill_append = seq.data.get_num_uncomputed_tokens() == 0
            seq.append_token_id(sample.output_token, sample.logprobs)
            if not prefill_append:
                group.update_num_computed_tokens(1)

    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
Antoni Baum's avatar
Antoni Baum committed
1281
1282
        """Performs one decoding iteration and returns newly generated results.

1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refer to a group of sequences
                  that are generated from the same prompt.

1298
            - Step 2: Calls the distributed executor to execute the model.
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
1320
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
1321
1322
1323
1324
1325
1326
1327
1328
1329
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
Antoni Baum's avatar
Antoni Baum committed
1330
        """
1331
1332
1333
1334
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")
1335

1336
        # For llm_engine, there is no pipeline parallel support, so the engine
1337
        # used is always 0.
1338
1339
        virtual_engine = 0

1340
1341
        # These are cached outputs from previous iterations. None if on first
        # iteration
1342
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
1343
1344
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
1345
        allow_async_output_proc = cached_outputs.allow_async_output_proc
1346

1347
1348
        ctx = self.scheduler_contexts[virtual_engine]

1349
1350
1351
        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

1352
1353
1354
        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
1355
1356
1357
1358
1359
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
1360
            # Schedule iteration
1361
            (seq_group_metadata_list, scheduler_outputs,
1362
1363
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
1364

1365
1366
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
1367

1368
1369
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
1370
1371
1372
1373
1374
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]
1375

1376
1377
            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
1378
                self._process_model_outputs(ctx=ctx)
1379

1380
1381
1382
1383
1384
            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
1385
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
1386
                    allow_async_output_proc)
1387
1388
        else:
            finished_requests_ids = list()
1389
1390
1391

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
Antoni Baum's avatar
Antoni Baum committed
1392

1393
        if not scheduler_outputs.is_empty():
1394
1395
1396
1397
1398
1399

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
1400
                self._get_last_sampled_token_ids(virtual_engine)
1401

1402
            execute_model_req = ExecuteModelRequest(
1403
1404
1405
1406
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
1407
1408
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
1409
1410
1411
1412
1413
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

1414
            if allow_async_output_proc:
1415
1416
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]
1417

1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise
1435

1436
            # We need to do this here so that last step's sampled_token_ids can
1437
1438
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
1439
                self._update_cached_scheduler_output(virtual_engine, outputs)
1440
        else:
1441
1442
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
1443
1444
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
1445
            # No outputs in this case
1446
            outputs = []
Antoni Baum's avatar
Antoni Baum committed
1447

1448
1449
1450
1451
1452
1453
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
1454
            # clear the cache if we have finished all the steps.
1455
1456
1457
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()

1458
1459
1460
1461
1462
1463
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

1464
            # Add results to the output_queue
1465
1466
1467
1468
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
1469
1470
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)
1471
1472
1473

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
1474
                    "Async postprocessor expects only a single output set")
1475

1476
                self._advance_to_next_step(
1477
                    outputs[0], seq_group_metadata_list,
1478
                    scheduler_outputs.scheduled_seq_groups)
1479

1480
            # Check if need to run the usual non-async path
1481
            if not allow_async_output_proc:
1482
                self._process_model_outputs(ctx=ctx)
1483

1484
                # Log stats.
1485
                self.do_log_stats(scheduler_outputs, outputs)
1486

1487
1488
1489
                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
1490
            # Multi-step case
1491
            return ctx.request_outputs
1492

1493
        if not self.has_unfinished_requests():
1494
1495
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
1496
                self._process_model_outputs(ctx=ctx)
1497
            assert len(ctx.output_queue) == 0
1498

1499
1500
1501
1502
1503
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
1504
            logger.debug("Stopping remote worker execution loop.")
1505
1506
            self.model_executor.stop_remote_worker_execution_loop()

1507
        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1508

1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Aborts a single request, and caches the scheduler outputs minus that
        request. This allows the next step to continue processing the remaining
        requests without having to re-run the scheduler.

        ``seq_group_metadata_list`` and
        ``scheduler_outputs.scheduled_seq_groups`` are mutated in place.

        Args:
            request_id: ID of the request whose input could not be processed.
            virtual_engine: Index of the virtual engine whose schedule is
                cached for reuse.
            seq_group_metadata_list: Metadata of the current schedule.
            scheduler_outputs: Scheduler outputs of the current schedule.
            allow_async_output_proc: Whether async output processing was
                allowed for this schedule; cached alongside it.
        """

        # Abort the request and remove its sequence group from the current
        # schedule
        self.abort_request(request_id)
        for i, metadata in enumerate(seq_group_metadata_list):
            if metadata.request_id == request_id:
                del seq_group_metadata_list[i]
                break
        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
            if group.seq_group.request_id == request_id:
                del scheduler_outputs.scheduled_seq_groups[i]
                # Prefill groups come before decode groups and the first
                # `num_prefill_groups` entries are treated as prefills, so
                # removing a prefill group must shift that boundary too.
                if i < scheduler_outputs.num_prefill_groups:
                    scheduler_outputs.num_prefill_groups -= 1
                break

        # If there are still other sequence groups left in the schedule, cache
        # them and flag the engine to reuse the schedule.
        if len(seq_group_metadata_list) > 0:
            self._skip_scheduling_next_step = True
            # Reuse multi-step caching logic
            self._cache_scheduler_outputs_for_multi_step(
                virtual_engine=virtual_engine,
                scheduler_outputs=scheduler_outputs,
                seq_group_metadata_list=seq_group_metadata_list,
                allow_async_output_proc=allow_async_output_proc)
        else:
            # Nothing is left to retry. The flag may still be set from an
            # earlier failure (this list may be the cached schedule being
            # retried), and leaving it set would make step() keep reusing the
            # now-empty cached schedule instead of calling the scheduler.
            self._skip_scheduling_next_step = False

1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return whether the current multi-step batch has steps left to run.

        Always ``False`` when multi-step scheduling is disabled or when
        ``seq_group_metadata_list`` is ``None`` or empty.

        Raises:
            AssertionError: If the sequence groups do not all report the same
                number of remaining steps.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        # Generator rather than a list so any() stops at the first mismatch.
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store a schedule in the per-virtual-engine cache for later reuse.

        Overwrites the cached metadata list, scheduler outputs and async flag
        for ``virtual_engine``, and resets the cached ``last_output`` to
        ``None``.
        """
        cache_entry = self.cached_scheduler_outputs[virtual_engine]

        cache_entry.seq_group_metadata_list = seq_group_metadata_list
        cache_entry.scheduler_outputs = scheduler_outputs
        cache_entry.allow_async_output_proc = allow_async_output_proc
        # No sampled output has been produced for this schedule yet.
        cache_entry.last_output = None
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Cache the newest sampler output for pipeline-parallel execution.

        Does nothing unless pipeline parallelism is enabled and ``output`` is
        a non-empty list whose first entry is not ``None``. In that case the
        *last* entry is stored as ``last_output`` on the cached state of
        ``virtual_engine``.
        """
        if self.parallel_config.pipeline_parallel_size <= 1:
            return
        if not output or output[0] is None:
            return

        newest = output[-1]
        # Only the CPU copy of the sampled token ids may be present here;
        # the device tensors and probabilities are required to be absent.
        assert newest is not None
        assert newest.sampled_token_ids_cpu is not None
        assert newest.sampled_token_ids is None
        assert newest.sampled_token_probs is None
        self.cached_scheduler_outputs[virtual_engine].last_output = newest

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1597
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` in ``self.stat_loggers`` under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled.
            KeyError: If a logger is already registered under that name.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")

        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled.
            KeyError: If no logger is registered under that name.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")

        registry = self.stat_loggers
        if logger_name not in registry:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del registry[logger_name]

1615
1616
1617
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Build a ``Stats`` snapshot and pass it to every registered logger.

        Does nothing when stat logging is disabled. The arguments are
        forwarded unchanged to ``_get_stats``. Calling with no arguments
        forces a log even when no requests are active.
        """
        if not self.log_stats:
            return

        stats = self._get_stats(scheduler_outputs, model_output,
                                finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(stats)
1626

1627
1628
1629
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional, indices into
                ``scheduler_outputs.scheduled_seq_groups`` of groups that had
                already finished before (async output processing). These
                groups are skipped, and each one reduces the batched-token
                count used for the generation-token stat by one.
            skip: Optional, indices into
                ``scheduler_outputs.scheduled_seq_groups`` of groups that were
                preempted. These groups are skipped.

        Returns:
            A ``Stats`` snapshot of system, iteration and request metrics.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        time_in_queue_requests: List[float] = []
        model_forward_time_requests: List[float] = []
        model_execute_time_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            # Build the index sets once so each membership test in the loop
            # below is O(1) instead of a scan of the argument list.
            finished_before_idxs = set(finished_before or ())
            skip_idxs = set(skip or ())

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_before_idxs:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skip_idxs:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_token_latency()
                    time_per_output_tokens_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens += (
                            seq_group.state.num_steps - 1)
                    else:
                        actual_num_batched_tokens += (
                            seq_group.state.current_step - 1)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    metrics = seq_group.metrics
                    # Latency timings
                    time_e2e_requests.append(now - metrics.arrival_time)
                    if (metrics.first_scheduled_time is not None
                            and metrics.first_token_time is not None):
                        time_queue_requests.append(
                            metrics.first_scheduled_time -
                            metrics.arrival_time)
                        time_prefill_requests.append(
                            metrics.first_token_time -
                            metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - metrics.first_token_time)
                        time_inference_requests.append(
                            now - metrics.first_scheduled_time)
                    if metrics.time_in_queue is not None:
                        time_in_queue_requests.append(metrics.time_in_queue)
                    if metrics.model_forward_time is not None:
                        model_forward_time_requests.append(
                            metrics.model_forward_time)
                    if metrics.model_execute_time is not None:
                        model_execute_time_requests.append(
                            metrics.model_execute_time * 1000)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and isinstance(model_output[0], SamplerOutput) and (
                model_output[0].spec_decode_worker_metrics is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            time_in_queue_requests=time_in_queue_requests,
            model_forward_time_requests=model_forward_time_requests,
            model_execute_time_requests=model_execute_time_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            # Already a str; no further conversion needed.
            max_lora=max_lora_stat,
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
1896

1897
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model executor's ``add_lora``.

        Returns whatever the executor returns.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1899
1900

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``remove_lora``.

        Returns whatever the executor returns.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1902

1903
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model executor's
        ``list_loras``."""
        executor = self.model_executor
        return executor.list_loras()
1905

1906
1907
1908
    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``pin_lora``.

        Returns whatever the executor returns.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward the request to the model executor's
        ``add_prompt_adapter`` and return its result."""
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward the id to the model executor's ``remove_prompt_adapter``
        and return its result."""
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the prompt adapter ids reported by the model executor."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1919
1920
1921
1922
1923
1924
    def start_profile(self) -> None:
        """Ask the model executor to start profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Ask the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

1925
1926
1927
1928
1929
    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at the given ``level``.

        Raises:
            AssertionError: If sleep mode is not enabled in the model config.
                Raised explicitly rather than via an ``assert`` statement so
                the check still runs under ``python -O``.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            raise AssertionError(
                "Sleep mode is not enabled in the model config")
        self.model_executor.sleep(level=level)

1930
    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the model executor, forwarding ``tags`` unchanged.

        Raises:
            AssertionError: If sleep mode is not enabled in the model config.
                Raised explicitly rather than via an ``assert`` statement so
                the check still runs under ``python -O``.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            raise AssertionError(
                "Sleep mode is not enabled in the model config")
        self.model_executor.wake_up(tags)
1934

1935
1936
1937
    def is_sleeping(self) -> bool:
        """Return the model executor's ``is_sleeping`` attribute."""
        executor = self.model_executor
        return executor.is_sleeping

1938
    def check_health(self) -> None:
        """Delegate the health check to the model executor.

        Any exception raised by the executor propagates to the caller.
        """
        executor = self.model_executor
        executor.check_health()
1940
1941
1942
1943

    def is_tracing_enabled(self) -> bool:
        """Return ``True`` when a tracer has been configured."""
        tracer = self.tracer
        return tracer is not None

1944
1945
1946
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for each finished group in the schedule.

        No-op when no tracer is configured. Positions listed in
        ``finished_before`` (indices into
        ``scheduler_outputs.scheduled_seq_groups``) are skipped to avoid
        tracing the same group twice when async output processing is used.
        """
        if self.tracer is None:
            return

        already_traced = finished_before or ()
        groups = scheduler_outputs.scheduled_seq_groups
        for position, scheduled in enumerate(groups):
            if position in already_traced:
                continue
            group = scheduled.seq_group
            if group.is_finished():
                self.create_trace_span(group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record one ``llm_request`` span for a sequence group.

        No-op when tracing is disabled or the group has no sampling params.
        Otherwise a SERVER-kind span is started at the request's arrival
        time, parented to the context extracted from the group's trace
        headers, and annotated with the model name, request parameters,
        token usage, and latency metrics.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        # Integer start time = arrival_time * 1e9. Assumes arrival_time is in
        # seconds and the tracer wants nanoseconds — confirm.
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            # NOTE(review): assumes first_token_time and finished_time are
            # populated for the groups passed here; a None would raise
            # TypeError before any attribute is set.
            ttft = metrics.first_token_time - metrics.arrival_time
            e2e_time = metrics.finished_time - metrics.arrival_time
            sampling_params = seq_group.sampling_params

            # Request identity and sampling parameters.
            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   sampling_params.n)

            # Usage: sequence count, prompt length, and total output length
            # over the group's finished sequences.
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Latencies that are always available.
            seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                                   metrics.time_in_queue)
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)

            # Optional latencies: only recorded when the metric was collected.
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # Unlike the other latencies this one is divided by 1000;
                # presumably it is recorded in milliseconds — confirm.
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)
2014

2015
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Validate processed prompt inputs before they are scheduled.

        The inputs are split into encoder and decoder parts. The encoder
        part is validated first, but only when present (not ``None``). The
        decoder part is always validated. Errors come from
        :meth:`_validate_model_input`.

        Args:
            inputs: Processed inputs for one request.
            lora_request: Forwarded so the LoRA-specific tokenizer is used.
        """
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")
2027

2028
2029
2030
2031
2032
2033
2034
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate a single (encoder or decoder) prompt of a request.

        Checks that the prompt's token IDs are non-empty and no longer than
        ``model_config.max_model_len``.

        Args:
            prompt_inputs: Processed singleton inputs; only the
                ``"prompt_token_ids"`` entry is read.
            lora_request: Selects the LoRA-specific tokenizer (when the engine
                has a tokenizer), used only to build the multimodal processor.
            prompt_type: Which prompt is being validated. It selects the
                multimodal-encoder exemptions and appears in error messages.

        Raises:
            ValueError: If the prompt is empty (tolerated only for a
                multimodal encoder prompt). Also if it is longer than
                ``max_model_len``, which is tolerated only for a multimodal
                encoder prompt whose processor sets
                ``pad_dummy_encoder_prompt``.
            AssertionError: If, in the multimodal-encoder length path, the
                created processor is not an ``EncDecMultiModalProcessor``.
        """
        model_config = self.model_config
        tokenizer = (None if self.tokenizer is None else
                     self.tokenizer.get_lora_tokenizer(lora_request))

        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        max_prompt_len = model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer or object(),  # Dummy if no tokenizer
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the guided_decoding,
        logit_bias, allowed_token_ids and bad_words fields in sampling_params.

        The guided_decoding, logit_bias and allowed_token_ids fields are
        cleared once their processors are built; bad_words is left as-is.
        The constructed processors are appended, in that order, after any
        processors already present in the logits_processors field.

        The caller's ``sampling_params`` object is never mutated. When any
        processor has to be built, the changes go to a shallow copy whose
        ``logits_processors`` list is freshly allocated. This keeps a single
        SamplingParams object that is shared by several requests from
        collecting per-request processors.

        Returns:
            The params to use for this request. This is the original object
            when nothing needs to be built, otherwise a modified copy.
        """
        has_guided = sampling_params.guided_decoding is not None
        has_openai = bool(sampling_params.logit_bias
                          or sampling_params.allowed_token_ids)
        # Truthiness also tolerates bad_words being None.
        has_bad_words = bool(sampling_params.bad_words)

        if not (has_guided or has_openai or has_bad_words):
            return sampling_params

        # Defensively copy sampling params: guided decoding logits processors
        # can have different state for each request, and the caller's object
        # may be shared across requests, so all edits below go to the copy.
        sampling_params = copy.copy(sampling_params)
        logits_processors = []

        if has_guided:
            # The shallow copy above still shares guided_decoding with the
            # caller; copy it too so setting ``backend`` stays local.
            guided_decoding = copy.copy(sampling_params.guided_decoding)

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            if self.decoding_config.reasoning_backend is not None:
                logger.debug("Building with reasoning backend %s",
                             self.decoding_config.reasoning_backend)

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config,
                reasoning_backend=self.decoding_config.reasoning_backend,
            )
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if has_openai:
            tokenizer = self.get_tokenizer(lora_request=lora_request)

            processors = get_openai_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if has_bad_words:
            tokenizer = self.get_tokenizer(lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

        if logits_processors:
            # Allocate a new list instead of extending in place: the shallow
            # copy still shares the caller's logits_processors list.
            existing = sampling_params.logits_processors or []
            sampling_params.logits_processors = [
                *existing, *logits_processors
            ]

        return sampling_params
2145

2146
2147
2148
2149
2150
2151
2152
2153
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC to the model executor.

        ``method``, ``timeout``, ``args`` and ``kwargs`` are passed through
        positionally and unchanged. The executor's result list is returned
        as-is.
        """
        executor = self.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)

2154

2155
2156
2157
# Import-time engine switch. When VLLM_USE_V1 passes ``envs.is_set`` (presumably
# meaning it is explicitly present in the environment — confirm) and also
# evaluates truthy, this module's ``LLMEngine`` name is rebound to the V1
# implementation. Code that imports ``LLMEngine`` from this module after it has
# loaded then receives the V1 class. The class previously bound to the name is
# shadowed, not removed, and the alias ``V1LLMEngine`` also becomes a module
# attribute.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    LLMEngine = V1LLMEngine  # type: ignore