llm_engine.py 91.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import copy
Antoni Baum's avatar
Antoni Baum committed
4
import time
5
from collections import Counter as collectionsCounter
6
from collections import deque
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from functools import partial
10
11
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
                    List, Mapping, NamedTuple, Optional)
12
from typing import Sequence as GenericSequence
13
from typing import Set, Type, Union, cast, overload
14

15
import torch
16
from typing_extensions import TypeVar, deprecated
17

18
import vllm.envs as envs
19
20
21
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
22
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
28
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
29
30
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
31
from vllm.executor.executor_base import ExecutorBase
32
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
33
                         PromptType, SingletonInputsAdapter)
34
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
35
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.logger import init_logger
37
from vllm.logits_process import get_bad_words_logits_processors
38
from vllm.lora.request import LoRARequest
39
40
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
41
from vllm.model_executor.layers.sampler import SamplerOutput
42
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
43
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
44
45
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
46
from vllm.prompt_adapter.request import PromptAdapterRequest
47
from vllm.sampling_params import RequestOutputKind, SamplingParams
48
49
50
51
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceStatus)
52
53
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
54
from vllm.transformers_utils.detokenizer import Detokenizer
55
from vllm.transformers_utils.tokenizer import AnyTokenizer
56
from vllm.transformers_utils.tokenizer_group import (
57
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
58
59
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
60
61
from vllm.utils import (Counter, Device, deprecate_kwargs,
                        resolve_obj_by_qualname, weak_bind)
62
from vllm.version import __version__ as VLLM_VERSION
63
from vllm.worker.model_runner_base import InputProcessingError
64
65

logger = init_logger(__name__)
66
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
67

68
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
69
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
70
71


72
73
74
75
76
@dataclass
class SchedulerOutputState:
    """Cached scheduler outputs for one virtual engine (used for multi-step).

    Every field has an "empty" default, so a freshly constructed instance
    holds no cached scheduling data.
    """
    # Sequence-group metadata belonging to the cached scheduling result.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # The cached scheduler outputs themselves.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # NOTE(review): presumably records whether async output processing was
    # permitted for the cached step -- confirm against the code that sets it.
    allow_async_output_proc: bool = False
    # NOTE(review): presumably the most recent sampler output for this
    # virtual engine -- confirm against the code that sets it.
    last_output: Optional[SamplerOutput] = None
79
80


81
82
83
84
85
86
class OutputData(NamedTuple):
    """Record queued by ``SchedulerContext.append_output``.

    Bundles a list of sampler outputs with the sequence-group metadata and
    scheduler outputs that were supplied alongside them, plus a few flags
    describing the step.
    """
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # True when this output comes from the first step of a multi-step run.
    # With multi-step disabled this is always True. The flag is not
    # meaningful when `outputs` contains outputs from several steps.
    is_first_step_output: Optional[bool]
    # Starts out empty when the record is created by
    # SchedulerContext.append_output.
    # NOTE(review): presumably indices to skip while processing -- confirm.
    skip: List[int]


96
class SchedulerContext:
    """Mutable per-scheduler state: pending outputs and finished results."""

    def __init__(self, multi_step_stream_outputs: bool = False):
        """Create an empty context.

        Args:
            multi_step_stream_outputs: Stored as-is on the instance.
        """
        # Queue of OutputData records added through append_output().
        self.output_queue: Deque[OutputData] = deque()

        # Request outputs accumulated for this context.
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []

        # Scheduling data; both start unset.
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Wrap the arguments in an ``OutputData`` and enqueue it.

        The record's ``skip`` list always starts out empty.
        """
        record = OutputData(
            outputs=outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            scheduler_outputs=scheduler_outputs,
            is_async=is_async,
            is_last_step=is_last_step,
            is_first_step_output=is_first_step_output,
            skip=[],
        )
        self.output_queue.append(record)
121
122


123
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
124
    """An LLM engine that receives requests and generates texts.
125

Woosuk Kwon's avatar
Woosuk Kwon committed
126
    This is the main class for the vLLM engine. It receives requests
127
128
129
130
131
132
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

133
134
    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
135

136
    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
137
    :ref:`engine-args`)
138
139
140
141
142
143
144

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
145
        device_config: The configuration related to the device.
146
147
148
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
149
150
        executor_class: The model executor class for managing distributed
            execution.
151
        prompt_adapter_config (Optional): The configuration related to serving
152
            prompt adapters.
153
        log_stats: Whether to log statistics.
154
        usage_context: Specified entry point, used for usage info collection.
155
    """
156

157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` typed as ``output_type``, optionally checked.

        The ``isinstance`` check only runs when ``TYPE_CHECKING`` or
        ``DO_VALIDATE_OUTPUT`` is truthy; otherwise the value is returned
        unchecked.

        Raises:
            TypeError: If checking is active and ``output`` is not an
                instance of ``output_type``.
        """
        checking_enabled = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if not checking_enabled:
            return cast(_O, output)

        if isinstance(output, output_type):
            return cast(_O, output)

        raise TypeError(f"Expected output of type {output_type}, "
                        f"but found type {type(output)}")
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return ``outputs`` typed as a list of ``output_type``.

        When neither ``TYPE_CHECKING`` nor ``DO_VALIDATE_OUTPUT`` is truthy,
        the very same ``outputs`` object is handed back without inspection.
        Otherwise each element is checked and a new list is returned.

        Raises:
            TypeError: If checking is active and an element is not an
                instance of ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # Fast path: no copy, no per-element check.
            return cast(List[_O], outputs)

        checked: List[_O] = []
        for output in outputs:
            if isinstance(output, output_type):
                checked.append(output)
                continue
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return checked

    tokenizer: Optional[BaseTokenizerGroup]

208
209
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Construct a V0 engine from a fully built ``VllmConfig``.

        Args:
            vllm_config: Aggregate configuration; its sub-configs are copied
                onto the engine as individual attributes.
            executor_class: Executor type, instantiated with
                ``vllm_config=vllm_config``.
            log_stats: Whether stat loggers are set up.
            usage_context: Forwarded to the usage report when usage stats
                are enabled.
            stat_loggers: Loggers to use when ``log_stats`` is true. If
                ``None``, a logging logger and a Prometheus logger are
                created.
            input_registry: Registry used to create the input processor.
            mm_registry: Multi-modal registry handed to the input
                preprocessor.
            use_cached_outputs: Stored on the engine and included in the
                start-up log line.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is set.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Expose the individual sub-configs as engine attributes.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.load_config = vllm_config.load_config
        # Fall back to default-constructed configs when these are unset.
        self.decoding_config = (vllm_config.decoding_config
                                or DecodingConfig())
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = (vllm_config.observability_config
                                     or ObservabilityConfig())

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        # Tokenizer / detokenizer are optional.
        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None
        else:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()

        # This closure deliberately captures only `tokenizer_group`, never
        # `self`, to avoid engine GC issues.
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(
            self.model_config,
            self.tokenizer,
            mm_registry,
        )

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(vllm_config=vllm_config)

        # Pooling runners skip KV-cache initialization.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)

            architecture = get_architecture_class_name(self.model_config)
            usage_kvs = {
                # Common configuration
                "dtype":
                str(self.model_config.dtype),
                "tensor_parallel_size":
                self.parallel_config.tensor_parallel_size,
                "block_size":
                self.cache_config.block_size,
                "gpu_memory_utilization":
                self.cache_config.gpu_memory_utilization,

                # Quantization
                "quantization":
                self.model_config.quantization,
                "kv_cache_dtype":
                str(self.cache_config.cache_dtype),

                # Feature flags
                "enable_lora":
                bool(self.lora_config),
                "enable_prompt_adapter":
                bool(self.prompt_adapter_config),
                "enable_prefix_caching":
                self.cache_config.enable_prefix_caching,
                "enforce_eager":
                self.model_config.enforce_eager,
                "disable_custom_all_reduce":
                self.parallel_config.disable_custom_all_reduce,
            }
            usage_message.report_usage(architecture,
                                       usage_context,
                                       extra_kvs=usage_kvs)

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cache / context / callback / scheduler per pipeline stage.
        pp_size = self.parallel_config.pipeline_parallel_size

        self.cached_scheduler_outputs = [
            SchedulerOutputState() for _ in range(pp_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(pp_size)
        ]

        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)
            self.async_callbacks = [
                partial(process_model_outputs, ctx=scheduler_ctx)
                for scheduler_ctx in self.scheduler_contexts
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        scheduler_cls = self.vllm_config.scheduler_config.scheduler_cls
        if isinstance(scheduler_cls, str):
            # A qualified-name string is resolved to the actual class.
            scheduler_cls = resolve_obj_by_qualname(scheduler_cls)

        self.scheduler = [
            scheduler_cls(
                self.scheduler_config,
                self.cache_config,
                self.lora_config,
                pp_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None,
            ) for v_id in range(pp_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is None:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                logging_stat_logger = LoggingStatLogger(
                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                    vllm_config=vllm_config)
                prometheus_stat_logger = PrometheusStatLogger(
                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                    labels=dict(
                        model_name=self.model_config.served_model_name),
                    vllm_config=vllm_config)
                self.stat_loggers = {
                    "logging": logging_stat_logger,
                    "prometheus": prometheus_stat_logger,
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)
            else:
                self.stat_loggers = stat_loggers

        # Tracing is only enabled when an OTLP endpoint is configured.
        self.tracer = None
        otlp_endpoint = self.observability_config.otlp_traces_endpoint
        if otlp_endpoint:
            self.tracer = init_tracer("vllm.llm_engine", otlp_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        stop_checker = StopChecker(
            self.scheduler_config.max_model_len,
            get_tokenizer_for_seq,
        )
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=stop_checker,
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

424
425
426
427
428
429
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache. If ``cache_config.num_gpu_blocks_override``
        is set, it replaces the GPU block count reported by the executor.
        The final counts are written back to ``cache_config`` and passed to
        the executor's ``initialize_cache``. The elapsed time is logged.
        """
        # perf_counter() is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments while the caches are being set up.
        start = time.perf_counter()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
        if num_gpu_blocks_override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        # Record the final block counts on the config.
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
449

450
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Map ``parallel_config.distributed_executor_backend`` to a class.

        The backend may be an ``ExecutorBase`` subclass, which is used
        directly, or one of the strings ``"ray"``, ``"mp"``, ``"uni"`` or
        ``"external_launcher"``. Executor modules are imported only for the
        branch that is actually taken.

        Raises:
            TypeError: If the backend is a class that does not derive from
                ``ExecutorBase``.
            ValueError: If the backend is not recognized.
        """
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)

        # A class is accepted as-is, provided it is an executor.
        if isinstance(distributed_executor_backend, type):
            if issubclass(distributed_executor_backend, ExecutorBase):
                return distributed_executor_backend
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorBase. Got {distributed_executor_backend}.")

        if distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            return RayDistributedExecutor

        if distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            return MultiprocessingDistributedExecutor

        if distributed_executor_backend == "uni":
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            return UniProcExecutor

        if distributed_executor_backend == "external_launcher":
            # executor with external launcher
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            return ExecutorWithExternalLauncher

        raise ValueError("unrecognized distributed_executor_backend: "
                         f"{distributed_executor_backend}")

488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Build an engine from a ready-made ``VllmConfig``.

        The executor class is derived from the config via
        ``_get_executor_cls``; stat logging is enabled unless
        ``disable_log_stats`` is true.
        """
        executor_class = cls._get_executor_cls(vllm_config)
        log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

504
505
506
507
508
509
510
511
512
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Build an engine from ``EngineArgs``.

        The engine config is created first; afterwards, if
        ``envs.VLLM_USE_V1`` is set, construction is delegated to the V1
        ``LLMEngine`` class instead of ``cls``.
        """
        # Create the engine configs. This must happen before the
        # VLLM_USE_V1 check below, as in the original ordering.
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            engine_cls = V1LLMEngine
        else:
            engine_cls = cls

        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
        )
526

527
528
529
530
531
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

532
533
534
535
536
537
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

538
    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the engine's tokenizer group, checked against a type.

        Args:
            group_type: Type the tokenizer group must be an instance of.

        Raises:
            ValueError: If the engine has no tokenizer
                (``self.tokenizer`` is ``None``).
            TypeError: If the tokenizer group is not a ``group_type``.
        """
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        if isinstance(tokenizer_group, group_type):
            return tokenizer_group

        raise TypeError("Invalid type of tokenizer group. "
                        f"Expected type: {group_type}, but "
                        f"found type: {type(tokenizer_group)}")
553

554
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group yields for a request.

        Args:
            lora_request: Forwarded to the group's ``get_lora_tokenizer``.
        """
        tokenizer_group = self.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
559

560
561
562
563
564
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Create the tokenizer group from the engine's configs."""
        tokenizer_group = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            lora_config=self.lora_config,
        )
        return tokenizer_group
566

567
568
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
569
        self.cache_config.verify_with_parallel_config(self.parallel_config)
570
571
572
573
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
574
575
576
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)
577

578
579
580
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add an already-processed request to the engine's request pool.

        Sampling requests with ``n > 1`` are handed off to
        ``ParallelSampleSequenceGroup.add_request`` and ``None`` is
        returned. Otherwise a sequence group is created, queued on the
        scheduler with the fewest unfinished sequence groups, and returned.

        Raises:
            ValueError: If ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``.
        """
        wants_parallel_sampling = (isinstance(params, SamplingParams)
                                   and params.n > 1)
        if wants_parallel_sampling:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        # Encoder/decoder inputs carry two parts; anything else is treated
        # as decoder-only input.
        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = processed_inputs["decoder"]
            encoder_inputs = processed_inputs["encoder"]
        else:
            decoder_inputs = processed_inputs
            encoder_inputs = None

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        # The encoder sequence, when present, shares the decoder's seq_id.
        if encoder_inputs is None:
            encoder_seq = None
        else:
            encoder_seq = Sequence(seq_id, encoder_inputs, block_size,
                                   eos_token_id, lora_request,
                                   prompt_adapter_request)

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        # min() keeps the first scheduler on ties, like index(min(...)).
        least_loaded_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        least_loaded_scheduler.add_seq_group(seq_group)

        return seq_group

662
663
    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()
664

665
    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Typing overload: the call form that passes ``prompt``."""
        ...

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Typing overload for the deprecated keyword-only ``inputs`` form.

        Marked ``@deprecated`` so callers are pointed at the ``prompt``
        parameter instead.
        """
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
711
    ) -> None:
Zhuohan Li's avatar
Zhuohan Li committed
712
        """Add a request to the engine's request pool.
713
714

        The request is added to the request pool and will be processed by the
Zhuohan Li's avatar
Zhuohan Li committed
715
        scheduler as `engine.step()` is called. The exact scheduling policy is
716
717
718
719
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
720
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
721
722
723
724
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
725
            arrival_time: The arrival time of the request. If None, we use
726
                the current monotonic time.
727
            lora_request: The LoRA request to add.
728
            trace_headers: OpenTelemetry trace headers.
729
            prompt_adapter_request: The prompt adapter request to add.
730
731
            priority: The priority of the request.
                Only applicable with priority scheduling.
732
733
734
735

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
736
            - Create `n` number of :class:`~vllm.Sequence` objects.
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
756
        """
757
758
759
760
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

761
762
763
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
764

765
        if priority != 0 and not self.scheduler_config.policy == "priority":
766
767
768
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

769
770
771
772
773
774
775
        if isinstance(params, SamplingParams) \
            and (params.guided_decoding or params.logits_processors) \
            and self.scheduler_config.num_scheduler_steps > 1:
            raise ValueError(
                "Guided decoding and logits processors are not supported "
                "in multi-step decoding")

776
        if arrival_time is None:
777
            arrival_time = time.time()
778

779
780
781
782
783
        if self.tokenizer is not None:
            self._validate_token_prompt(
                prompt,
                tokenizer=self.get_tokenizer(lora_request=lora_request))

784
        preprocessed_inputs = self.input_preprocessor.preprocess(
785
            prompt,
786
787
            request_id=request_id,
            lora_request=lora_request,
788
789
            prompt_adapter_request=prompt_adapter_request,
        )
790
        processed_inputs = self.input_processor(preprocessed_inputs)
791

792
        self._add_processed_request(
793
794
795
796
797
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
798
            prompt_adapter_request=prompt_adapter_request,
799
            trace_headers=trace_headers,
800
            priority=priority,
801
        )
802

803
804
805
806
807
808
809
810
811
812
813
    def _validate_token_prompt(self, prompt: PromptType,
                               tokenizer: AnyTokenizer):
        """Reject token prompts that contain an out-of-vocabulary token id.

        Only prompts for which ``is_token_prompt`` is true are inspected;
        an empty ``prompt_token_ids`` list is accepted here.

        Raises:
            ValueError: If the largest token id in the prompt is greater than
                ``tokenizer.max_token_id``.
        """
        # Why this check exists:
        # Some tokenizers decode out-of-vocab token ids to empty text without
        # complaint, and ids above the max token id are not detected before
        # the model runs. Such ids later crash a cuda kernel with an index
        # out of bounds error, which takes down the entire engine.
        # The check has to happen before multimodal input pre-processing,
        # which may add dummy <image> tokens that aren't part of the
        # tokenizer's vocabulary.
        if not is_token_prompt(prompt):
            return

        token_ids = prompt["prompt_token_ids"]
        if len(token_ids) == 0:
            # Empty prompt check is handled later
            return

        largest_id = max(token_ids)
        if largest_id <= tokenizer.max_token_id:
            return
        raise ValueError(
            "Token id {} is out of vocabulary".format(largest_id))

824
825
826
827
828
    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
829
830
        arrival_time: float,
        lora_request: Optional[LoRARequest],
831
        trace_headers: Optional[Mapping[str, str]] = None,
832
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
833
        encoder_seq: Optional[Sequence] = None,
834
        priority: int = 0,
835
836
837
838
839
840
841
842
843
844
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams."""
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

845
846
847
        sampling_params = self._build_logits_processors(
            sampling_params, lora_request)

848
849
850
        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
851

852
        sampling_params.update_from_generation_config(
853
            self.generation_config_fields, seq.eos_token_id)
854

855
        # Create the sequence group.
856
857
858
859
860
861
862
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
863
            prompt_adapter_request=prompt_adapter_request,
864
865
            encoder_seq=encoder_seq,
            priority=priority)
866

867
868
869
870
871
872
873
        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup driven by PoolingParams."""
        # The pooler uses these params, so hand it a private copy rather
        # than the caller's object.
        params_copy = pooling_params.clone()
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=params_copy,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)
894

Antoni Baum's avatar
Antoni Baum committed
895
896
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests on every scheduler of this engine.

        Args:
            request_id: The ID (or an iterable of IDs) of the request(s)
                to abort.

        Details:
            - The call is forwarded to
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              of each scheduler, together with the engine's
              ``seq_id_to_seq_group`` mapping.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for sched in self.scheduler:
            sched.abort_seq_group(
                request_id,
                seq_id_to_seq_group=self.seq_id_to_seq_group,
            )
915

916
917
918
919
    def get_model_config(self) -> ModelConfig:
        """Return the model configuration held by this engine."""
        model_config = self.model_config
        return model_config

920
921
922
923
    def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration held by this engine."""
        parallel_config = self.parallel_config
        return parallel_config

924
925
926
927
    def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration held by this engine."""
        decoding_config = self.decoding_config
        return decoding_config

928
929
930
931
932
933
934
935
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration held by this engine."""
        scheduler_config = self.scheduler_config
        return scheduler_config

    def get_lora_config(self) -> LoRAConfig:
        """Return the LoRA configuration held by this engine."""
        lora_config = self.lora_config
        return lora_config

936
    def get_num_unfinished_requests(self) -> int:
        """Return the unfinished sequence-group count over all schedulers."""
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total
940

941
    def has_unfinished_requests(self) -> bool:
        """Return True if any scheduler still has unfinished sequences."""
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Return whether the given virtual engine has unfinished sequences.

        ``virtual_engine`` is used as an index into ``self.scheduler``.
        """
        sched = self.scheduler[virtual_engine]
        return sched.has_unfinished_seqs()
952

953
954
955
956
957
958
959
960
    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache of every scheduler.

        Every scheduler is asked to reset, even if an earlier one reports
        failure, so that one failing scheduler does not leave the others
        untouched.

        Returns:
            True if all schedulers reported success (also when there are no
            schedulers), False otherwise.
        """
        # Materialize the results first: feeding a generator to all() would
        # stop at the first failure and skip the remaining schedulers.
        results = [
            scheduler.reset_prefix_cache() for scheduler in self.scheduler
        ]
        return all(results)

961
    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Store pooled data on the group and mark its sequences as stopped.

        Only the first element of ``outputs`` is used: its ``data`` becomes
        ``seq_group.pooled_data``. Every sequence returned by
        ``seq_group.get_seqs()`` is then set to
        ``SequenceStatus.FINISHED_STOPPED``.
        """
        first_output = outputs[0]
        seq_group.pooled_data = first_output.data

        for sequence in seq_group.get_seqs():
            sequence.status = SequenceStatus.FINISHED_STOPPED

973
974
975
976
977
978
979
980
    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """
        This function updates num_computed_tokens for prompt sequences
        when Multi-Step is enabled.

981
        seq_group: SequenceGroup to update the num_computed_tokens for.
982
        seq_group_meta: Metadata of the given SequenceGroup.
983
        is_first_step_output: Optional[bool] -
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
            When available, is_first_step_output indicates if the appended
            output token is the output of the first-step in multi-step.
            A value of None indicates that outputs from all steps in
            in multi-step are submitted in a single burst.
        """

        assert self.scheduler_config.is_multi_step

        if not seq_group_meta.is_prompt:
            # num_computed_token updates for multi-step decodes happen after
            # the tokens are appended to the sequence.
            return

        do_update: bool = False
        if self.scheduler_config.chunked_prefill_enabled:
            # In multi-step + chunked-prefill case, the prompt sequences
            # that are scheduled are fully processed in the first step.
            do_update = is_first_step_output is None or is_first_step_output
        else:
            # Normal multi-step decoding case. In this case prompt-sequences
            # are actually single-stepped. Always update in this case.
            assert seq_group.state.num_steps == 1
            do_update = True

        if do_update:
            seq_group.update_num_computed_tokens(
                seq_group_meta.token_chunk_size)

1012
1013
1014
1015
1016
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and return responses.
1017

1018
1019
        ctx: The virtual engine context to work on
        request_id: If provided, then only this request is going to be processed
1020
        """
1021

1022
        now = time.time()
1023

1024
        if len(ctx.output_queue) == 0:
1025
1026
            return None

1027
        # Get pending async postprocessor
1028
1029
1030
1031
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
1032
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
1033
1034
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
1035
1036
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()
1037
1038
1039
1040
1041

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

1042
        has_multiple_outputs: bool = len(outputs) > 1
1043
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
1044
1045
1046
1047
1048
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
            if self.scheduler_config.is_multi_step:
                outputs_by_sequence_group = create_output_by_sequence_group(
                    outputs, len(seq_group_metadata_list))
            elif self.speculative_config:
                # Decodes are multi-steps while prefills are not, outputting at
                # most 1 token. Separate them so that we can trigger chunk
                # processing without having to pad or copy over prompts K times
                # to match decodes structure (costly with prompt_logprobs).
                num_prefills = sum(sg.is_prompt
                                   for sg in seq_group_metadata_list)
                prefills, decodes = outputs[:num_prefills], outputs[
                    num_prefills:]
                outputs_by_sequence_group = create_output_by_sequence_group(
                    decodes,
                    num_seq_groups=len(seq_group_metadata_list) - num_prefills)
                outputs_by_sequence_group = [p.outputs for p in prefills
                                             ] + outputs_by_sequence_group
1066
1067
1068
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
1069
1070
1071
        else:
            outputs_by_sequence_group = outputs

1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

1089
        finished_before: List[int] = []
1090
        finished_now: List[int] = []
1091
1092
1093
1094
1095
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
1096
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
1097

1098
            seq_group: SequenceGroup = scheduled_seq_group.seq_group
1099
1100
1101
1102
1103

            if seq_group.is_finished():
                finished_before.append(i)
                continue

1104
            output: List[SequenceGroupOutput]
1105
            if has_multiple_outputs:
1106
1107
1108
1109
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

1110
1111
1112
1113
1114
1115
1116
            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
1117
                        seq_group_meta.token_chunk_size or 0)
1118
1119
1120

            if outputs:
                for o in outputs:
1121
1122
1123
1124
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
1125
                                o.model_forward_time or 0)
1126
1127
1128
1129
1130
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
1131
                                o.model_execute_time or 0)
1132
1133
1134
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)
1135

1136
            if self.model_config.runner_type == "pooling":
1137
                self._process_sequence_group_outputs(seq_group, output)
1138
1139
1140
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
1141
                    self.output_processor.process_outputs(
1142
                        seq_group, output, is_async)
1143

1144
1145
            if seq_group.is_finished():
                finished_now.append(i)
1146

1147
1148
1149
        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
1150

1151
1152
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
1153
1154
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
1155
            request_output = RequestOutputFactory.create(
1156
1157
1158
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
1159
1160
            if request_output:
                ctx.request_outputs.append(request_output)
1161

1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

1174
1175
1176
1177
1178
        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

1179
1180
        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
1181
1182
1183
1184
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
1185
                ctx.request_outputs.clear()
1186
1187
1188
            return

        # Create the outputs
1189
1190
        for i in indices:
            if i in skip or i in finished_before or i in finished_now:
1191
1192
                continue  # Avoids double processing

1193
1194
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

1195
            seq_group = scheduled_seq_group.seq_group
1196
            seq_group.maybe_set_first_token_time(now)
1197
1198
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
1199
            request_output = RequestOutputFactory.create(
1200
1201
1202
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
1203
            if request_output:
1204
                ctx.request_outputs.append(request_output)
1205

1206
1207
1208
1209
1210
1211
1212
1213
        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

1214
        for seq_group in scheduler_outputs.ignored_seq_groups:
1215
1216
1217
1218
1219
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

1220
            request_output = RequestOutputFactory.create(
1221
1222
1223
1224
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
1225
1226
            if request_output:
                ctx.request_outputs.append(request_output)
1227

1228
1229
1230
1231
        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
1232
            ctx.request_outputs.clear()
1233

1234
1235
1236
1237
        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
1238
            # Log stats.
1239
1240
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)
1241
1242

            # Tracing
1243
            self.do_tracing(scheduler_outputs, finished_before)
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261

        return None

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done inside output processor, but it is
        required if the worker is to perform async forward pass to next step.
        """
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

1262
1263
1264
1265
1266
            if self.scheduler_config.is_multi_step:
                # Updates happen only if the sequence is prefill
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, seq_group_metadata,
                    seq_group.state.num_steps == 1)
1267
            else:
1268
1269
1270
1271
                token_chunk_size = (seq_group_metadata.token_chunk_size
                                    if seq_group_metadata.token_chunk_size
                                    is not None else 0)
                seq_group.update_num_computed_tokens(token_chunk_size)
1272

1273
1274
1275
            if seq_group_metadata.do_sample:
                assert len(sequence_group_outputs.samples) == 1, (
                    "Async output processor expects a single sample"
1276
                    " (i.e sampling_params.n == 1)")
1277
1278
1279
1280
                sample = sequence_group_outputs.samples[0]

                assert len(seq_group.seqs) == 1
                seq = seq_group.seqs[0]
1281
1282
1283
1284
1285
1286
1287
1288
1289

                if self.scheduler_config.is_multi_step:
                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
                    ) == 0
                    seq.append_token_id(sample.output_token, sample.logprobs)
                    if not is_prefill_append:
                        seq_group.update_num_computed_tokens(1)
                else:
                    seq.append_token_id(sample.output_token, sample.logprobs)
1290

1291
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
Antoni Baum's avatar
Antoni Baum committed
1292
1293
        """Performs one decoding iteration and returns newly generated results.

1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refer to a group of sequences
                  that are generated from the same prompt.

1309
            - Step 2: Calls the distributed executor to execute the model.
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
1331
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
1332
1333
1334
1335
1336
1337
1338
1339
1340
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
Antoni Baum's avatar
Antoni Baum committed
1341
        """
1342
1343
1344
1345
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")
1346

1347
        # For llm_engine, there is no pipeline parallel support, so the engine
1348
        # used is always 0.
1349
1350
        virtual_engine = 0

1351
1352
        # These are cached outputs from previous iterations. None if on first
        # iteration
1353
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
1354
1355
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
1356
        allow_async_output_proc = cached_outputs.allow_async_output_proc
1357

1358
1359
        ctx = self.scheduler_contexts[virtual_engine]

1360
1361
1362
        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

1363
1364
1365
        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
1366
1367
1368
1369
1370
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
1371
            # Schedule iteration
1372
            (seq_group_metadata_list, scheduler_outputs,
1373
1374
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
1375

1376
1377
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
1378

1379
1380
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
1381
1382
1383
1384
1385
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]
1386

1387
1388
            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
1389
                self._process_model_outputs(ctx=ctx)
1390

1391
1392
1393
1394
1395
            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
1396
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
1397
                    allow_async_output_proc)
1398
1399
        else:
            finished_requests_ids = list()
1400
1401
1402

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
Antoni Baum's avatar
Antoni Baum committed
1403

1404
        if not scheduler_outputs.is_empty():
1405
1406
1407
1408
1409
1410

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
1411
                self._get_last_sampled_token_ids(virtual_engine)
1412

1413
            execute_model_req = ExecuteModelRequest(
1414
1415
1416
1417
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
1418
1419
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
1420
1421
1422
1423
1424
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

1425
            if allow_async_output_proc:
1426
1427
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]
1428

1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise
1446

1447
            # We need to do this here so that last step's sampled_token_ids can
1448
1449
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
1450
                self._update_cached_scheduler_output(virtual_engine, outputs)
1451
        else:
1452
1453
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
1454
1455
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
1456
            # No outputs in this case
1457
            outputs = []
Antoni Baum's avatar
Antoni Baum committed
1458

1459
1460
1461
1462
1463
1464
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
1465
            # clear the cache if we have finished all the steps.
1466
1467
1468
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()

1469
1470
1471
1472
1473
1474
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

1475
            # Add results to the output_queue
1476
1477
1478
1479
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
1480
1481
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)
1482
1483
1484

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
1485
                    "Async postprocessor expects only a single output set")
1486

1487
                self._advance_to_next_step(
1488
                    outputs[0], seq_group_metadata_list,
1489
                    scheduler_outputs.scheduled_seq_groups)
1490

1491
            # Check if need to run the usual non-async path
1492
            if not allow_async_output_proc:
1493
                self._process_model_outputs(ctx=ctx)
1494

1495
                # Log stats.
1496
                self.do_log_stats(scheduler_outputs, outputs)
1497

1498
1499
1500
                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
1501
            # Multi-step case
1502
            return ctx.request_outputs
1503

1504
        if not self.has_unfinished_requests():
1505
1506
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
1507
                self._process_model_outputs(ctx=ctx)
1508
            assert len(ctx.output_queue) == 0
1509

1510
1511
1512
1513
1514
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
1515
            logger.debug("Stopping remote worker execution loop.")
1516
1517
            self.model_executor.stop_remote_worker_execution_loop()

1518
        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1519

1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Abort one request and keep the rest of the schedule for reuse.

        The request is aborted and its entries are removed in place from
        ``seq_group_metadata_list`` and from
        ``scheduler_outputs.scheduled_seq_groups``. If any sequence group is
        left afterwards, the trimmed schedule is cached and the engine is
        flagged so that the next step reuses it instead of re-running the
        scheduler.

        Args:
            request_id: ID of the request to abort.
            virtual_engine: Index of the cached-scheduler-output slot to use.
            seq_group_metadata_list: Metadata of the current schedule;
                mutated in place.
            scheduler_outputs: Scheduler outputs of the current schedule;
                ``scheduled_seq_groups`` is mutated in place.
            allow_async_output_proc: Stored alongside the cached schedule.
        """
        self.abort_request(request_id)

        # Remove the first metadata entry that belongs to the aborted request.
        metadata_pos = next(
            (pos for pos, entry in enumerate(seq_group_metadata_list)
             if entry.request_id == request_id), None)
        if metadata_pos is not None:
            del seq_group_metadata_list[metadata_pos]

        # Same for the matching scheduled sequence group.
        scheduled = scheduler_outputs.scheduled_seq_groups
        scheduled_pos = next(
            (pos for pos, entry in enumerate(scheduled)
             if entry.seq_group.request_id == request_id), None)
        if scheduled_pos is not None:
            del scheduled[scheduled_pos]

        # Nothing left to rerun: leave the cache and the skip flag untouched.
        if len(seq_group_metadata_list) == 0:
            return

        self._skip_scheduling_next_step = True
        # Reuse multi-step caching logic
        self._cache_scheduler_outputs_for_multi_step(
            virtual_engine=virtual_engine,
            scheduler_outputs=scheduler_outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            allow_async_output_proc=allow_async_output_proc)

1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return whether the given batch still has multi-step steps to run.

        Args:
            seq_group_metadata_list: Metadata of the batch to inspect; may be
                ``None`` or empty.

        Returns:
            ``False`` when multi-step scheduling is disabled or the list is
            ``None``/empty; otherwise whether the (shared) remaining step
            count is greater than zero.

        Raises:
            AssertionError: If the sequence groups do not all report the same
                number of remaining steps.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort
        # of dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        # Generator form lets any() stop at the first mismatch.
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store a schedule in the cache slot of ``virtual_engine``.

        Overwrites the slot's metadata list, scheduler outputs and
        async-output flag with the given values and resets its
        ``last_output`` to ``None``.

        Args:
            virtual_engine: Index into ``self.cached_scheduler_outputs``.
            seq_group_metadata_list: Metadata list to cache.
            scheduler_outputs: Scheduler outputs to cache.
            allow_async_output_proc: Async-output flag to cache.
        """
        cached_state = self.cached_scheduler_outputs[virtual_engine]

        cached_state.seq_group_metadata_list = seq_group_metadata_list
        cached_state.scheduler_outputs = scheduler_outputs
        cached_state.allow_async_output_proc = allow_async_output_proc
        # A freshly cached schedule has no model output attached yet.
        cached_state.last_output = None
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Remember the last sampler output for ``virtual_engine``.

        Only acts when pipeline parallelism is in use
        (``pipeline_parallel_size > 1``) and ``output`` is non-empty with a
        non-``None`` first element. In that case the final element of
        ``output`` is stored as ``last_output`` of the cached scheduler state.

        Args:
            virtual_engine: Index into ``self.cached_scheduler_outputs``.
            output: Sampler outputs returned by the model executor.
        """
        if self.parallel_config.pipeline_parallel_size <= 1:
            return
        if len(output) == 0 or output[0] is None:
            return

        final_output = output[-1]
        # The cached output must carry only the CPU copy of the sampled ids.
        assert final_output is not None
        assert final_output.sampled_token_ids_cpu is not None
        assert final_output.sampled_token_ids is None
        assert final_output.sampled_token_probs is None

        cached_state = self.cached_scheduler_outputs[virtual_engine]
        cached_state.last_output = final_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1608
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under ``logger_name``.

        Args:
            logger_name: Key under which the logger is stored.
            logger: The stat logger to register.

        Raises:
            RuntimeError: If stat logging is disabled on this engine.
            KeyError: If a logger with this name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")

        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Args:
            logger_name: Key of the logger to remove.

        Raises:
            RuntimeError: If stat logging is disabled on this engine.
            KeyError: If no logger with this name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")

        registry = self.stat_loggers
        if logger_name not in registry:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del registry[logger_name]

1626
1627
1628
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Collect stats and hand them to every registered stat logger.

        Does nothing when stat logging is disabled. All arguments are
        forwarded unchanged to ``_get_stats``; calling with no arguments
        forces a log even when no requests are active.

        Args:
            scheduler_outputs: Optional scheduler outputs of the last step.
            model_output: Optional sampler outputs of the last step.
            finished_before: Optional indices of scheduled sequence groups
                to ignore because they finished earlier.
            skip: Optional indices of scheduled sequence groups to ignore.
        """
        if not self.log_stats:
            return

        stats = self._get_stats(scheduler_outputs, model_output,
                                finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(stats)
1637

1638
1639
1640
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional, indices of sequences that were finished
                before. These sequences will be ignored.
            skip: Optional, indices of sequences that were preempted. These
                sequences will be ignored.

        Returns:
            A ``Stats`` snapshot holding system-level, per-iteration and
            per-finished-request metrics.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage (computed as the fraction 1.0 - free / total)
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        time_in_queue_requests: List[float] = []
        model_forward_time_requests: List[float] = []
        model_execute_time_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests
        # Only the distinct adapter names are reported, in first-seen order,
        # so an insertion-ordered dict is enough (no counts are needed).
        running_lora_adapters = list(
            dict.fromkeys(running_request.lora_request.lora_name
                          for scheduler in self.scheduler
                          for running_request in scheduler.running
                          if running_request.lora_request))
        waiting_lora_adapters = list(
            dict.fromkeys(waiting_request.lora_request.lora_name
                          for scheduler in self.scheduler
                          for waiting_request in scheduler.waiting
                          if waiting_request.lora_request))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # Build the index filters once so the per-group membership tests in
        # the loop below are O(1) instead of a list scan.
        finished_before_idx = frozenset(finished_before or ())
        skip_idx = frozenset(skip or ())

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_before_idx:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skip_idx:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_token_latency()
                    time_per_output_tokens_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
                    if seq_group.metrics.time_in_queue is not None:
                        time_in_queue_requests.append(
                            seq_group.metrics.time_in_queue)
                    if seq_group.metrics.model_forward_time is not None:
                        model_forward_time_requests.append(
                            seq_group.metrics.model_forward_time)
                    if seq_group.metrics.model_execute_time is not None:
                        model_execute_time_requests.append(
                            seq_group.metrics.model_execute_time * 1000)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and isinstance(model_output[0], SamplerOutput) and (
                model_output[0].spec_decode_worker_metrics is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            time_in_queue_requests=time_in_queue_requests,
            model_forward_time_requests=model_forward_time_requests,
            model_execute_time_requests=model_execute_time_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            max_lora=max_lora_stat,
            waiting_lora_adapters=waiting_lora_adapters,
            running_lora_adapters=running_lora_adapters)
1907

1908
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the model executor.

        Returns whatever ``model_executor.add_lora`` returns.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1910
1911

    def remove_lora(self, lora_id: int) -> bool:
        """Forward removal of the LoRA adapter ``lora_id`` to the executor.

        Returns whatever ``model_executor.remove_lora`` returns.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1913

1914
    def list_loras(self) -> Set[int]:
        """Return the LoRA IDs reported by ``model_executor.list_loras``."""
        executor = self.model_executor
        return executor.list_loras()
1916

1917
1918
1919
    def pin_lora(self, lora_id: int) -> bool:
        """Forward pinning of the LoRA adapter ``lora_id`` to the executor.

        Returns whatever ``model_executor.pin_lora`` returns.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward a prompt-adapter add request to the model executor.

        Returns whatever ``model_executor.add_prompt_adapter`` returns.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward removal of a prompt adapter to the model executor.

        Returns whatever ``model_executor.remove_prompt_adapter`` returns.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return what ``model_executor.list_prompt_adapters`` reports."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1930
1931
1932
1933
1934
1935
    def start_profile(self) -> None:
        """Ask the model executor to start profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Ask the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at the given level.

        Args:
            level: Passed through to ``model_executor.sleep``.

        Raises:
            AssertionError: If sleep mode is not enabled in the model config.
        """
        sleep_mode_enabled = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_mode_enabled, (
            "Sleep mode is not enabled in the model config")
        self.model_executor.sleep(level=level)

    def wake_up(self) -> None:
        """Wake the model executor up from sleep mode.

        Raises:
            AssertionError: If sleep mode is not enabled in the model config.
        """
        sleep_mode_enabled = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_mode_enabled, (
            "Sleep mode is not enabled in the model config")
        self.model_executor.wake_up()

1946
    def check_health(self) -> None:
        """Run the health checks of the tokenizer (if set) and the executor.

        The tokenizer check is skipped when ``self.tokenizer`` is falsy; the
        model executor check always runs. Any exception raised by either
        check propagates to the caller.
        """
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        self.model_executor.check_health()
1950
1951
1952
1953

    def is_tracing_enabled(self) -> bool:
        """Return ``True`` when a tracer is configured on this engine."""
        tracer = self.tracer
        return tracer is not None

1954
1955
1956
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for every finished scheduled sequence group.

        Does nothing when no tracer is configured.

        Args:
            scheduler_outputs: Scheduler outputs whose
                ``scheduled_seq_groups`` are inspected.
            finished_before: Optional indices of scheduled sequence groups
                that must be skipped to avoid tracing them twice when using
                async output processing.
        """
        if self.tracer is None:
            return

        # Build the filter once so each membership test below is O(1).
        already_traced = frozenset(finished_before or ())

        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if idx in already_traced:
                continue

            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record an ``llm_request`` span describing ``seq_group``.

        Returns without doing anything when no tracer is configured or the
        sequence group has no sampling params. Otherwise a server-kind span
        is started at the group's arrival time and populated with request
        attributes (model, request id, sampling settings), usage counts and
        latency measurements.

        Latency attributes whose underlying metric is ``None`` are omitted
        rather than computed or written, so a group with an unset
        ``first_token_time`` or ``finished_time`` no longer raises
        ``TypeError`` here.

        Args:
            seq_group: The sequence group whose metrics are exported.
        """
        sampling_params = seq_group.sampling_params
        if self.tracer is None or sampling_params is None:
            return

        metrics = seq_group.metrics
        arrival_time = metrics.arrival_time
        # The span's start_time is passed as an integer scaled by 1e9.
        arrival_time_nano_seconds = int(arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            # Derived latencies: only computable when the timestamp exists.
            ttft = (metrics.first_token_time - arrival_time
                    if metrics.first_token_time is not None else None)
            e2e_time = (metrics.finished_time - arrival_time
                        if metrics.finished_time is not None else None)
            # NOTE(review): the /1000.0 presumably converts milliseconds to
            # seconds -- confirm against where model_forward_time is set.
            model_forward_time = (metrics.model_forward_time / 1000.0
                                  if metrics.model_forward_time is not None
                                  else None)

            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   sampling_params.n)
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Latency attributes are written in the same order as before,
            # each one skipped when its value is unavailable.
            latency_attributes = (
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                 metrics.time_in_queue),
                (SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft),
                (SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time),
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                 metrics.scheduler_time),
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                 model_forward_time),
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                 metrics.model_execute_time),
            )
            for attribute_name, value in latency_attributes:
                if value is not None:
                    seq_span.set_attribute(attribute_name, value)
2024

2025
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Reject empty prompts and over-long multimodal prompts.

        Args:
            inputs: Processed inputs. For encoder-decoder inputs only one
                side is checked (see below).
            lora_request: Not used by this check.

        Raises:
            ValueError: If the selected prompt has no token ids, or if the
                model is multimodal and the prompt is longer than
                ``max_model_len``.
        """
        model_config = self.model_config

        if not is_encoder_decoder_inputs(inputs):
            prompt_inputs = inputs
        else:
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            side = "decoder" if model_config.is_multimodal_model else "encoder"
            prompt_inputs = inputs[side]

        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        # The length limit below is only enforced for multimodal models.
        if not model_config.is_multimodal_model:
            return

        max_prompt_len = model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            raise ValueError(
                f"The prompt (total length {len(prompt_ids)}) is too long "
                f"to fit into the model (context length {max_prompt_len}). "
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens plus multimodal tokens. For image "
                "inputs, the number of image tokens depends on the number "
                "of images, and possibly their aspect ratios as well.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the guided_decoding,
        logit_bias, allowed_token_ids, and bad_words fields in
        sampling_params.

        The guided_decoding, logit_bias, and allowed_token_ids fields are
        unset once consumed (bad_words is left in place), and the constructed
        processors are added to the logits_processors field.

        Args:
            sampling_params: Params to read from. A shallow copy is made only
                when guided decoding is requested; otherwise the passed
                object itself is modified.
            lora_request: Used to select the tokenizer via ``get_tokenizer``.

        Returns:
            The (possibly copied) sampling params with the processors added.
        """

        logits_processors = []

        if sampling_params.guided_decoding is not None:
            # Defensively copy sampling params since guided decoding logits
            # processors can have different state for each request
            sampling_params = copy.copy(sampling_params)
            guided_decoding = sampling_params.guided_decoding

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            # Fall back to the engine-wide backend when the request did not
            # pick one. The copy above is shallow, so this write lands on
            # the guided_decoding object shared with the caller's params.
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            logger.debug("Reasoning backend: %s",
                         self.decoding_config.reasoning_backend)

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config,
                reasoning_backend=self.decoding_config.reasoning_backend,
            )
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
            tokenizer = self.get_tokenizer(lora_request=lora_request)

            processors = get_openai_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if sampling_params.bad_words:
            tokenizer = self.get_tokenizer(lora_request=lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

        if logits_processors:
            existing = sampling_params.logits_processors
            if existing is None:
                sampling_params.logits_processors = logits_processors
            else:
                # Build a new list instead of extending in place: after the
                # shallow copy above, ``existing`` is the very list object
                # held by the caller's params, so ``extend`` would leak this
                # request's processors into every later request that reuses
                # those params.
                sampling_params.logits_processors = [
                    *existing, *logits_processors
                ]

        return sampling_params
2121
2122


2123
2124
2125
# When VLLM_USE_V1 is set (per ``envs.is_set``) and truthy, rebind the
# module-level name ``LLMEngine`` to the V1 engine class. The V1 module is
# imported only inside this branch, i.e. only when V1 is selected.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    LLMEngine = V1LLMEngine  # type: ignore