llm_engine.py 85 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from collections import Counter as collectionsCounter
3
from collections import deque
4
from contextlib import contextmanager
5
from dataclasses import dataclass
6
from functools import partial
7
8
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                    Iterable, List, Mapping, NamedTuple, Optional)
9
from typing import Sequence as GenericSequence
10
from typing import Set, Type, Union, cast, overload
11

12
import torch
13
from typing_extensions import TypeVar
14

15
import vllm.envs as envs
16
17
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
18
                         ObservabilityConfig, ParallelConfig,
19
                         PromptAdapterConfig, SchedulerConfig,
20
                         SpeculativeConfig)
21
22
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
28
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
29
from vllm.entrypoints.openai.logits_processors import get_logits_processors
30
from vllm.executor.executor_base import ExecutorBase
31
from vllm.executor.gpu_executor import GPUExecutor
32
from vllm.executor.ray_utils import initialize_ray_cluster
33
34
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
                         EncoderDecoderInputs, InputRegistry, PromptType)
35
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.logger import init_logger
37
from vllm.lora.request import LoRARequest
38
39
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
40
from vllm.model_executor.layers.sampler import SamplerOutput
41
42
43
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
44
from vllm.prompt_adapter.request import PromptAdapterRequest
45
from vllm.sampling_params import RequestOutputKind, SamplingParams
46
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
47
48
49
50
                           ParallelSampleSequenceGroup, Sequence,
                           SequenceGroup, SequenceGroupBase,
                           SequenceGroupMetadata, SequenceGroupOutput,
                           SequenceStatus)
51
52
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
53
from vllm.transformers_utils.config import try_get_generation_config
54
from vllm.transformers_utils.detokenizer import Detokenizer
55
from vllm.transformers_utils.tokenizer import AnyTokenizer
56
from vllm.transformers_utils.tokenizer_group import (
57
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
58
59
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
60
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
61
from vllm.version import __version__ as VLLM_VERSION
62
63

logger = init_logger(__name__)
64
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
65

66

67
68
69
70
71
72
73
74
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a plain dict.

    The config is looked up with ``try_get_generation_config`` using the
    model name, ``trust_remote_code`` and ``revision`` taken from
    ``model_config``.

    Returns:
        The result of ``to_diff_dict()`` on the config that was found, or
        an empty dict when the lookup returns ``None``.
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if generation_config is not None:
        return generation_config.to_diff_dict()

    return {}

79

80
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
81
82
83
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)


84
85
86
87
88
@dataclass
class SchedulerOutputState:
    """Cached scheduler outputs for one virtual engine (used for Multi-Step).

    Every field has a default, so a fresh instance represents "nothing
    cached yet".
    """
    # Sequence group metadata cached alongside the scheduler outputs.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # The cached scheduler outputs themselves.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Cached flag; presumably whether async output processing was allowed
    # for the cached schedule -- confirm against the code that sets it.
    allow_async_output_proc: bool = False
    # Most recent sampler output stored with this cache entry, if any.
    last_output: Optional[SamplerOutput] = None
91
92


93
94
95
96
97
98
class OutputData(NamedTuple):
    """One queued batch of model outputs plus the scheduling state it
    belongs to, as stored in ``SchedulerContext.output_queue``."""

    # Sampler outputs contained in this entry.
    outputs: List[SamplerOutput]
    # Metadata of the sequence groups the outputs were produced for.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Scheduler outputs associated with this entry.
    scheduler_outputs: SchedulerOutputs
    # Flags passed through unchanged from the code that queued the entry.
    is_async: bool
    is_last_step: bool

    # Indicates if this output is from the first step of the
    # multi-step. When multi-step is disabled, this is always
    # set to True.
    # is_first_step_output is invalid when `outputs` has
    # outputs from multiple steps.
    is_first_step_output: Optional[bool]

    # List of ints that starts out empty when the entry is queued via
    # SchedulerContext.append_output. NOTE(review): presumably indices to
    # skip during output processing -- confirm against the consumer.
    skip: List[int]


108
class SchedulerContext:
    """Mutable holder for queued model outputs and the request outputs
    produced from them."""

    def __init__(self, multi_step_stream_outputs: bool = False):
        """Create an empty context.

        Args:
            multi_step_stream_outputs: Stored as-is on the instance for
                later use by the code that processes outputs.
        """
        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

        # Pending OutputData entries, appended by `append_output`.
        self.output_queue: Deque[OutputData] = deque()
        # Request outputs accumulated so far.
        self.request_outputs: List[Union[RequestOutput,
                                         EmbeddingRequestOutput]] = []

        # Scheduling state; unset until assigned from outside this class.
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Wrap the arguments in an ``OutputData`` (with an empty ``skip``
        list) and push it onto the right end of ``output_queue``."""
        entry = OutputData(
            outputs=outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            scheduler_outputs=scheduler_outputs,
            is_async=is_async,
            is_last_step=is_last_step,
            is_first_step_output=is_first_step_output,
            skip=[],
        )
        self.output_queue.append(entry)
133
134


135
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
136
    """An LLM engine that receives requests and generates texts.
137

Woosuk Kwon's avatar
Woosuk Kwon committed
138
    This is the main class for the vLLM engine. It receives requests
139
140
141
142
143
144
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

145
146
    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
147

148
149
    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine_args`)
150
151
152
153
154
155
156

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
157
        device_config: The configuration related to the device.
158
159
160
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
161
162
        executor_class: The model executor class for managing distributed
            execution.
163
        prompt_adapter_config (Optional): The configuration related to serving
164
            prompt adapters.
165
        log_stats: Whether to log statistics.
166
        usage_context: Specified entry point, used for usage info collection.
167
    """
168

169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` typed as ``output_type``.

        The ``isinstance`` check only runs when ``DO_VALIDATE_OUTPUT`` is
        set (or under ``TYPE_CHECKING``); otherwise the value is passed
        through unchecked.

        Raises:
            TypeError: If validation is active and ``output`` is not an
                instance of ``output_type``.
        """
        if TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT:
            if not isinstance(output, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(output)}")

        return cast(_O, output)
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Sequence counterpart of ``validate_output``.

        When validation is off, ``outputs`` itself is returned untouched.
        When it is on (``DO_VALIDATE_OUTPUT`` or ``TYPE_CHECKING``), every
        element is type-checked and a new list is returned.

        Raises:
            TypeError: If validation is active and any element is not an
                instance of ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # Fast path: hand the caller's sequence straight back.
            return outputs

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")

            checked.append(item)

        return checked

    tokenizer: Optional[BaseTokenizerGroup]

220
221
222
223
224
225
    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        observability_config: Optional[ObservabilityConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Set up the engine from fully-built config objects.

        In order, this logs the effective configuration, stores the
        configs, creates the tokenizer/detokenizer (unless
        ``skip_tokenizer_init`` is set), builds the input preprocessor and
        processor, instantiates ``executor_class``, initializes the KV
        caches (skipped when ``model_config.task == "embedding"``), reports
        usage stats if enabled, and finally creates one scheduler and one
        scheduler context per pipeline-parallel stage, the stat loggers,
        the tracer and the sequence output processor.

        Args:
            decoding_config: Falls back to ``DecodingConfig()`` when falsy.
            observability_config: Falls back to ``ObservabilityConfig()``
                when falsy.
            executor_class: Class instantiated with the configs to obtain
                ``self.model_executor``.
            log_stats: When true, ``self.stat_loggers`` is populated; it
                is left unset otherwise.
            usage_context: Forwarded to the usage-stats report.
            stat_loggers: Used as-is when given and ``log_stats`` is true;
                otherwise default logging and prometheus loggers are made.
            input_registry: Registry used to create the input processor.
            use_cached_outputs: Stored on the instance and logged.

        The remaining arguments are config objects that are stored on the
        instance and forwarded to the components created here.
        """
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "override_neuron_config=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "pipeline_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, "
            "num_scheduler_steps=%d, chunked_prefill_enabled=%s, "
            "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
            "use_async_output_proc=%s, use_cached_outputs=%s, "
            "chat_template_text_format=%s, mm_processor_kwargs=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.override_neuron_config,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            observability_config,
            model_config.seed,
            model_config.served_model_name,
            scheduler_config.num_scheduler_steps,
            scheduler_config.chunked_prefill_enabled,
            scheduler_config.multi_step_stream_outputs,
            cache_config.enable_prefix_caching,
            model_config.use_async_output_proc,
            use_cached_outputs,
            model_config.chat_template_text_format,
            model_config.mm_processor_kwargs,
        )
        # TODO(woosuk): Print more configs in debug mode.
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config or ObservabilityConfig(
        )
        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_preprocessor = InputPreprocessor(model_config,
                                                    self.tokenizer)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
            observability_config=self.observability_config,
        )

        # KV caches are not set up for the "embedding" task.
        if self.model_config.task != "embedding":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    str(cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prompt_adapter":
                    bool(prompt_adapter_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cache slot and one context per pipeline-parallel stage.
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if model_config.use_async_output_proc:
            # weak_bind is used so the callbacks handed to the schedulers
            # do not hold the bound method directly.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                scheduler_config, cache_config, lora_config,
                parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if model_config.use_async_output_proc else None)
            for v_id in range(parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is only enabled when an OTLP endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

482
483
484
485
486
487
488
489
490
491
492
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
493
494
495
496
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
497
498
499
500
501
502
503
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

504
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: EngineConfig) -> Type[ExecutorBase]:
        """Pick the executor class for the configured device and backend.

        The choice is driven by
        ``engine_config.parallel_config.distributed_executor_backend`` and
        ``engine_config.device_config.device_type``. A backend given as a
        class is used directly; otherwise device-specific executors are
        checked first (neuron, tpu, cpu, openvino, xpu), then the "ray" and
        "mp" GPU backends, with ``GPUExecutor`` as the default. The Ray
        cluster is initialized here whenever a Ray-based executor is
        selected.

        Raises:
            TypeError: If the backend is a class that does not subclass
                ``ExecutorBase``.
            NotImplementedError: If the "mp" backend is requested on XPU.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutor
                executor_class = RayTPUExecutor
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutor
                executor_class = TPUExecutor
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.device_config.device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            executor_class = OpenVINOExecutor
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutor
                executor_class = RayXPUExecutor
            elif distributed_executor_backend == "mp":
                # FIXME(kunshang):
                # spawn needs calling `if __name__ == '__main__':``
                # fork is not supported for xpu start new process.
                unsupported_msg = (
                    "Both start methods (spawn and fork) have issue "
                    "on XPU if you use mp backend, Please try ray instead.")
                logger.error(unsupported_msg)
                # No executor exists for this combination. Fail with an
                # explicit error instead of falling through to the final
                # `return` with `executor_class` unbound.
                raise NotImplementedError(unsupported_msg)
            else:
                from vllm.executor.xpu_executor import XPUExecutor
                executor_class = XPUExecutor
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingGPUExecutor
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The engine config is built from ``engine_args``, the executor class
        is selected from it, and the engine is constructed from the
        config's ``to_dict()`` plus the remaining keyword arguments.
        Statistics logging is on unless ``engine_args.disable_log_stats``
        is set.
        """
        config = engine_args.create_engine_config()
        executor_cls = cls._get_executor_cls(config)

        return cls(
            **config.to_dict(),
            executor_class=executor_cls,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
588

589
590
591
592
593
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

594
595
596
597
598
599
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

600
    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the engine's tokenizer group, checked against a type.

        Args:
            group_type: Class the tokenizer group must be an instance of.

        Raises:
            ValueError: If the engine has no tokenizer
                (``self.tokenizer`` is ``None``).
            TypeError: If the tokenizer group is not a ``group_type``.
        """
        group = self.tokenizer

        if group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        if isinstance(group, group_type):
            return group

        raise TypeError("Invalid type of tokenizer group. "
                        f"Expected type: {group_type}, but "
                        f"found type: {type(group)}")
615

616
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group provides for
        ``lora_request`` (``None`` by default).

        Raises whatever ``get_tokenizer_group`` raises when no tokenizer
        group is available.
        """
        tokenizer_group = self.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
621

622
623
624
625
626
627
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Build the tokenizer group from the engine's stored configs.

        LoRA support is requested exactly when ``self.lora_config`` is
        truthy.
        """
        lora_enabled = bool(self.lora_config)
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            enable_lora=lora_enabled,
        )
628

629
630
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
631
        self.cache_config.verify_with_parallel_config(self.parallel_config)
632
633
634
635
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
636
637
638
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)
639

640
641
642
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Turn processed inputs into a sequence group and schedule it.

        The inputs are validated, a decoder ``Sequence`` is built (plus an
        encoder ``Sequence`` sharing the same id when the inputs contain
        ``encoder_prompt_token_ids``), and a sequence group is created
        according to the type of ``params``. The group is added to the
        scheduler that currently reports the fewest unfinished sequence
        groups.

        Returns:
            The created sequence group.

        Raises:
            ValueError: If ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``.
        """
        self._validate_model_inputs(processed_inputs)

        # Create the sequences. Both share every positional argument, so
        # bind them once and only vary `from_decoder_prompt`.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
        make_sequence = partial(Sequence, seq_id, processed_inputs,
                                block_size, eos_token_id, lora_request,
                                prompt_adapter_request)

        seq = make_sequence()
        if 'encoder_prompt_token_ids' in processed_inputs:
            encoder_seq = make_sequence(from_decoder_prompt=False)
        else:
            encoder_seq = None

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished
        # seqs; on ties the earliest scheduler in the list is chosen.
        least_loaded = min(
            self.scheduler,
            key=lambda sched: sched.get_num_unfinished_seq_groups())
        least_loaded.add_seq_group(seq_group)

        return seq_group

709
710
    def stop_remote_worker_execution_loop(self) -> None:
        """Forward the stop request to the model executor."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
711

712
    @overload  # DEPRECATED
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> Optional[SequenceGroup]:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: Optional LoRA request. Only accepted when the
                engine has a LoRA config.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Optional prompt adapter request.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            inputs: Deprecated alias of ``prompt``. When given, it takes
                precedence over ``prompt``.

        Returns:
            The created :class:`~vllm.SequenceGroup`, or ``None`` when
            ``params`` is a :class:`~vllm.SamplingParams` with ``n > 1``; in
            that case the request is handed to
            ``ParallelSampleSequenceGroup.add_request`` instead.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled,
                or if ``priority`` is non-zero while the scheduler policy is
                not ``"priority"``.

        Details:
            - Delegate to ``ParallelSampleSequenceGroup.add_request`` when
              parallel sampling (``n > 1``) is requested.
            - Set arrival_time to the current time if it is None.
            - Run the prompt through the input preprocessor and the input
              processor.
            - Create the :class:`~vllm.Sequence` object(s) and a
              :class:`~vllm.SequenceGroup` object from them.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """

        # Parallel sampling is handled by a dedicated helper; nothing is
        # returned to the caller in that case.
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                prompt=prompt,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
                inputs=inputs,
            )
            return None

        # The deprecated keyword wins over the positional prompt.
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if arrival_time is None:
            arrival_time = time.time()

        preprocessed_inputs = self.input_preprocessor.preprocess(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        # This is a bit of a hack - copy the mm_processor_kwargs that were
        # used in the input processor to the processed output, since these
        # kwargs are presumed to be immutable and the values should be aligned
        # between the input processor (here) and the input mapper.
        processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
            "mm_processor_kwargs")

        return self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )
857
858
859
860
861
862

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams.

        Args:
            request_id: The unique ID of the request.
            seq: The (decoder) sequence to put into the group.
            sampling_params: Sampling parameters of the request. The caller's
                object is not the one stored in the group; a clone is.
            arrival_time: The arrival time of the request.
            lora_request: Optional LoRA request.
            trace_headers: Optional trace headers.
            prompt_adapter_request: Optional prompt adapter request.
            encoder_seq: Optional encoder sequence.
            priority: The priority of the request.

        Returns:
            The created sequence group.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model config's ``max_logprobs``.
        """
        max_logprobs = self.get_model_config().max_logprobs
        # Unset (None) or zero values are skipped by the truthiness test.
        requested_logprobs = (sampling_params.logprobs,
                              sampling_params.prompt_logprobs)
        if any(count and count > max_logprobs
               for count in requested_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

        sampling_params = self._build_logits_processors(
            sampling_params, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()

        sampling_params.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        # Create the sequence group.
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with PoolingParams.

        Args:
            request_id: The unique ID of the request.
            seq: The (decoder) sequence to put into the group.
            pooling_params: Pooling parameters of the request; a clone of
                them is stored in the group.
            arrival_time: The arrival time of the request.
            lora_request: Optional LoRA request.
            prompt_adapter_request: Optional prompt adapter request.
            encoder_seq: Optional encoder sequence.
            priority: The priority of the request.

        Returns:
            The created sequence group.
        """
        # Defensive copy of PoolingParams, which are used by the pooler
        pooling_params = pooling_params.clone()

        # Create the sequence group.
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)
928

Antoni Baum's avatar
Antoni Baum committed
929
930
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        Args:
            request_id: The ID(s) of the request to abort.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.
            - Every scheduler of the engine is asked to abort the given
              ID(s).

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        # A one-shot iterable (e.g. a generator) would be exhausted by the
        # first scheduler, leaving the remaining schedulers with nothing to
        # abort. Materialize it once so each scheduler sees all the IDs.
        if not isinstance(request_id, str):
            request_id = tuple(request_id)

        for scheduler in self.scheduler:
            scheduler.abort_seq_group(request_id)
948

949
950
951
952
    def get_model_config(self) -> ModelConfig:
        """Return the model configuration held by this engine."""
        model_config = self.model_config
        return model_config

953
954
955
956
    def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration held by this engine."""
        parallel_config = self.parallel_config
        return parallel_config

957
958
959
960
    def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration held by this engine."""
        decoding_config = self.decoding_config
        return decoding_config

961
962
963
964
965
966
967
968
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration held by this engine."""
        scheduler_config = self.scheduler_config
        return scheduler_config

    def get_lora_config(self) -> LoRAConfig:
        """Return the LoRA configuration held by this engine.

        NOTE(review): ``add_request`` tests ``not self.lora_config`` to detect
        that LoRA is disabled, so the returned value may be falsy (presumably
        ``None``) despite the annotation - confirm before relying on it.
        """
        lora_config = self.lora_config
        return lora_config

969
    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests.

        Returns:
            The sum, over all schedulers of this engine, of each scheduler's
            unfinished sequence group count (0 when there are no schedulers).
        """
        return sum(scheduler.get_num_unfinished_seq_groups()
                   for scheduler in self.scheduler)
973

974
    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests.

        Returns:
            True as soon as any scheduler of this engine reports unfinished
            sequences; False otherwise (including when there are no
            schedulers).
        """
        return any(scheduler.has_unfinished_seqs()
                   for scheduler in self.scheduler)

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Report whether one virtual engine still has unfinished requests.

        Args:
            virtual_engine: Index into this engine's list of schedulers.

        Returns:
            Whatever the selected scheduler's ``has_unfinished_seqs()``
            reports.
        """
        selected_scheduler = self.scheduler[virtual_engine]
        return selected_scheduler.has_unfinished_seqs()
985

986
    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        """Store embedding outputs on a sequence group and finish its seqs.

        Args:
            seq_group: The sequence group to update in place.
            outputs: Outputs for this group; only the first entry's
                ``embeddings`` is used.
        """
        seq_group.embeddings = outputs[0].embeddings

        # Every sequence of the group is marked as finished.
        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_STOPPED

998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """
        Update num_computed_tokens for prompt sequences when Multi-Step is
        enabled.

        seq_group: SequenceGroup to update the num_computed_tokens for.
        seq_group_meta: Metadata of the given SequenceGroup.
        is_first_step_output: Optional[bool] -
            When available, is_first_step_output indicates if the appended
            output token is the output of the first-step in multi-step.
            A value of None indicates that outputs from all steps in
            multi-step are submitted in a single burst.
        """
        scheduler_config = self.scheduler_config
        assert scheduler_config.is_multi_step

        # Decodes are not handled here: their num_computed_tokens updates
        # happen after the tokens are appended to the sequence.
        if not seq_group_meta.is_prompt:
            return

        if scheduler_config.chunked_prefill_enabled:
            # Multi-step + chunked-prefill: the scheduled prompt sequences
            # are fully processed in the first step, so only the first-step
            # output (or a single burst, signalled by None) triggers the
            # update.
            if is_first_step_output is not None and not is_first_step_output:
                return
        else:
            # Normal multi-step decoding: prompt sequences are actually
            # single-stepped, so the update always happens.
            assert seq_group.state.num_steps == 1

        seq_group.update_num_computed_tokens(seq_group_meta.token_chunk_size)

1037
1038
1039
1040
1041
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and produce the request outputs.

        Request outputs are appended to ``ctx.request_outputs`` and, when
        ``self.process_request_outputs_callback`` is set, handed to that
        callback (after which ``ctx.request_outputs`` is cleared). Nothing is
        returned.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, then only this request is going to be
                processed. In that case the head of ``ctx.output_queue`` is
                only peeked at, not popped, and the request's index is
                recorded in the entry's ``skip`` list so it is not processed
                again later.
        """

        now = time.time()

        if not ctx.output_queue:
            return

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Re-index the outputs so that outputs_by_sequence_group[i]
            # holds the outputs belonging to the i-th sequence group.
            outputs_by_sequence_group = create_output_by_sequence_group(
                outputs, num_seq_groups=len(seq_group_metadata_list))
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
        else:
            outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                # Single step: the only entry holds one output per group.
                output = [outputs_by_sequence_group[0][i]]

            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
                        seq_group_meta.token_chunk_size or 0)

            # Accumulate the model forward/execute times reported by the
            # sampler outputs into the group's metrics.
            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            if self.model_config.task == "embedding":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Create the outputs
        for i in indices:
            if i in skip or i in finished_before or i in finished_now:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        for seq_group in scheduler_outputs.ignored_seq_groups:
            # Unfinished groups in DELTA output mode produce no output here.
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done inside output processor, but it is
        required if the worker is to perform async forward pass to next step.

        Args:
            output: Model output of the run; iterated in lockstep with the
                two lists below, one entry per scheduled sequence group.
            seq_group_metadata_list: Metadata of the scheduled sequence
                groups.
            scheduled_seq_groups: The scheduled sequence groups themselves.
        """
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

            if self.scheduler_config.is_multi_step:
                # Updates happen only if the sequence is prefill
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, seq_group_metadata,
                    seq_group.state.num_steps == 1)
            else:
                # A missing chunk size counts as zero computed tokens.
                token_chunk_size = (seq_group_metadata.token_chunk_size
                                    if seq_group_metadata.token_chunk_size
                                    is not None else 0)
                seq_group.update_num_computed_tokens(token_chunk_size)

            if seq_group_metadata.do_sample:
                assert len(sequence_group_outputs.samples) == 1, (
                    "Async output processor expects a single sample"
                    " (i.e sampling_params.n == 1)")
                sample = sequence_group_outputs.samples[0]

                assert len(seq_group.seqs) == 1
                seq = seq_group.seqs[0]

                if self.scheduler_config.is_multi_step:
                    # Must be evaluated before the token is appended: it
                    # decides whether the append counts as a computed token.
                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
                    ) == 0
                    seq.append_token_id(sample.output_token, sample.logprobs)
                    if not is_prefill_append:
                        seq_group.update_num_computed_tokens(1)
                else:
                    seq.append_token_id(sample.output_token, sample.logprobs)
1296

1297
    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
Antoni Baum's avatar
Antoni Baum committed
1298
1299
        """Performs one decoding iteration and returns newly generated results.

1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refer to a group of sequences
                  that are generated from the same prompt.

1315
            - Step 2: Calls the distributed executor to execute the model.
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
1337
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
1338
1339
1340
1341
1342
1343
1344
1345
1346
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
Antoni Baum's avatar
Antoni Baum committed
1347
        """
1348
1349
1350
1351
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")
1352

1353
        # For llm_engine, there is no pipeline parallel support, so the engine
1354
        # used is always 0.
1355
1356
        virtual_engine = 0

1357
1358
        # These are cached outputs from previous iterations. None if on first
        # iteration
1359
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
1360
1361
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
1362
        allow_async_output_proc = cached_outputs.allow_async_output_proc
1363

1364
1365
        ctx = self.scheduler_contexts[virtual_engine]

1366
1367
1368
        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

1369
1370
1371
1372
        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
1373
            # Schedule iteration
1374
            (seq_group_metadata_list, scheduler_outputs,
1375
1376
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
1377

1378
1379
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
1380

1381
1382
            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
1383
                self._process_model_outputs(ctx=ctx)
1384

1385
1386
1387
1388
1389
            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
1390
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
1391
                    allow_async_output_proc)
1392
1393
1394

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
Antoni Baum's avatar
Antoni Baum committed
1395

1396
        if not scheduler_outputs.is_empty():
1397
            finished_requests_ids = self.scheduler[
1398
                virtual_engine].get_and_reset_finished_requests_ids()
1399
1400
1401
1402
1403
1404

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
1405
                self._get_last_sampled_token_ids(virtual_engine)
1406

1407
            execute_model_req = ExecuteModelRequest(
1408
1409
1410
1411
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
1412
1413
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
1414
1415
1416
1417
1418
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

1419
            if allow_async_output_proc:
1420
1421
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]
1422

1423
            outputs = self.model_executor.execute_model(
1424
                execute_model_req=execute_model_req)
1425

1426
            # We need to do this here so that last step's sampled_token_ids can
1427
1428
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
1429
                self._update_cached_scheduler_output(virtual_engine, outputs)
1430
        else:
1431
1432
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
1433
1434
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
1435
            # No outputs in this case
1436
            outputs = []
Antoni Baum's avatar
Antoni Baum committed
1437

1438
1439
1440
1441
1442
1443
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
1444
            # clear the cache if we have finished all the steps.
1445
1446
1447
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()

1448
1449
1450
1451
1452
1453
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

1454
            # Add results to the output_queue
1455
1456
1457
1458
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
1459
1460
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)
1461
1462
1463

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
1464
                    "Async postprocessor expects only a single output set")
1465

1466
                self._advance_to_next_step(
1467
                    outputs[0], seq_group_metadata_list,
1468
                    scheduler_outputs.scheduled_seq_groups)
1469

1470
            # Check if need to run the usual non-async path
1471
            if not allow_async_output_proc:
1472
                self._process_model_outputs(ctx=ctx)
1473

1474
                # Log stats.
1475
                self.do_log_stats(scheduler_outputs, outputs)
1476

1477
1478
1479
                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
1480
            # Multi-step case
1481
            return ctx.request_outputs
1482

1483
        if not self.has_unfinished_requests():
1484
1485
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
1486
                self._process_model_outputs(ctx=ctx)
1487
            assert len(ctx.output_queue) == 0
1488

1489
1490
1491
1492
1493
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
1494
            logger.debug("Stopping remote worker execution loop.")
1495
1496
            self.model_executor.stop_remote_worker_execution_loop()

1497
        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1498

1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for nowto make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        if any([
                seq_group.state.remaining_steps != ref_remaining_steps
                for seq_group in seq_group_metadata_list[1:]
        ]):
            raise AssertionError(("All running sequence groups should "
                                  "have the same remaining steps."))

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
1522
1523
1524
1525
1526
1527
1528
1529
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                and output[0] is not None):
            last_output = output[-1]
            assert last_output is not None
            assert last_output.sampled_token_ids_cpu is not None
            assert last_output.sampled_token_ids is None
            assert last_output.sampled_token_probs is None
            self.cached_scheduler_outputs[
                virtual_engine].last_output = last_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1555
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` under ``logger_name`` in ``self.stat_loggers``.

        Raises:
            RuntimeError: If ``self.log_stats`` is falsy.
            KeyError: If a logger with this name is already registered.
        """
        if not self.log_stats:
            disabled_msg = ("Stat logging is disabled. Set "
                            "`disable_log_stats=False` argument to enable.")
            raise RuntimeError(disabled_msg)

        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Drop the logger registered as ``logger_name``.

        Raises:
            RuntimeError: If ``self.log_stats`` is falsy.
            KeyError: If no logger with this name is registered.
        """
        if not self.log_stats:
            disabled_msg = ("Stat logging is disabled. Set "
                            "`disable_log_stats=False` argument to enable.")
            raise RuntimeError(disabled_msg)

        registry = self.stat_loggers
        if logger_name not in registry:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del registry[logger_name]

1573
1574
1575
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Build a stats snapshot and hand it to every registered logger.

        All arguments are forwarded unchanged to ``self._get_stats``. When
        ``self.log_stats`` is falsy the call does nothing.
        """
        if not self.log_stats:
            return

        snapshot = self._get_stats(scheduler_outputs, model_output,
                                   finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(snapshot)
1584

1585
1586
1587
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Assemble the ``Stats`` snapshot that is passed to the stat loggers.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional, indices of scheduled sequence groups
                that were already finished before. They are ignored, and each
                one reduces the batched-token count by one.
            skip: Optional, indices of scheduled sequence groups that were
                preempted. They are ignored.

        Returns:
            A ``Stats`` object holding system, iteration, request and LoRA
            figures gathered at the time of the call.
        """
        timestamp = time.time()
        schedulers = self.scheduler

        # System state: queue sizes summed over every scheduler.
        running_total = sum(len(sched.running) for sched in schedulers)
        swapped_total = sum(len(sched.swapped) for sched in schedulers)
        waiting_total = sum(len(sched.waiting) for sched in schedulers)

        # KV cache usage as a fraction of the total block count. The value
        # stays 0.0 when the block count is None or 0.
        gpu_usage = 0.
        gpu_blocks_total = self.cache_config.num_gpu_blocks
        if gpu_blocks_total:
            gpu_blocks_free = sum(
                sched.block_manager.get_num_free_gpu_blocks()
                for sched in schedulers)
            gpu_usage = 1.0 - (gpu_blocks_free / gpu_blocks_total)

        cpu_usage = 0.
        cpu_blocks_total = self.cache_config.num_cpu_blocks
        if cpu_blocks_total:
            cpu_blocks_free = sum(
                sched.block_manager.get_num_free_cpu_blocks()
                for sched in schedulers)
            cpu_usage = 1.0 - (cpu_blocks_free / cpu_blocks_total)

        # Prefix cache hit rate. Only the first virtual engine is consulted.
        first_scheduler = schedulers[0]
        cpu_hit_rate = first_scheduler.get_prefix_cache_hit_rate(Device.CPU)
        gpu_hit_rate = first_scheduler.get_prefix_cache_hit_rate(Device.GPU)

        # Iteration-level accumulators.
        prompt_tokens_iter = 0
        generation_tokens_iter = 0
        ttft_samples: List[float] = []
        tpot_samples: List[float] = []
        if scheduler_outputs is None:
            preemptions_iter = 0
        else:
            preemptions_iter = scheduler_outputs.preempted

        # Request-level accumulators (filled for finished requests only).
        e2e_latencies: List[float] = []
        prompt_lens: List[int] = []
        generation_lens: List[int] = []
        n_values: List[int] = []
        finish_reasons: List[str] = []

        # LoRA adapter names, de-duplicated in first-seen order.
        running_adapters = list(
            dict.fromkeys(request.lora_request.lora_name
                          for sched in schedulers
                          for request in sched.running
                          if request.lora_request))
        waiting_adapters = list(
            dict.fromkeys(request.lora_request.lora_name
                          for sched in schedulers
                          for request in sched.waiting
                          if request.lora_request))
        max_loras = (str(self.lora_config.max_loras)
                     if self.lora_config else "0")

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # With the async postprocessor, sequences that already finished
            # must not be counted again, so the total is adjusted below.
            batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore
            prefill_generation_tokens = 0.

            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.
            for position, scheduled in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Already logged on an earlier call (async output proc).
                if finished_before and position in finished_before:
                    batched_tokens -= 1
                    continue

                # Currently skip == preempted sequences; nothing is logged
                # for them.
                if skip and position in skip:
                    continue

                seq_group = scheduled.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration reports is_prefill() == False while
                # still sitting in the prefill part of the list.
                if position < scheduler_outputs.num_prefill_groups:
                    prompt_tokens_iter += scheduled.token_chunk_size

                    # The prefill just completed: record TTFT and count one
                    # generated token per sequence in the group.
                    if not seq_group.is_prefill():
                        ttft_samples.append(
                            seq_group.get_last_latency(timestamp))
                        prefill_generation_tokens += seq_group.num_seqs()
                else:
                    # Decode group: record TPOT.
                    tpot_samples.append(seq_group.get_last_latency(timestamp))

                    # For async_output_proc, do_log_stats() is called after
                    # init_multi_step(), which resets current_step to zero;
                    # in that case num_steps is used instead.
                    step_state = seq_group.state
                    if step_state.current_step == 0:
                        steps_counted = step_state.num_steps
                    else:
                        steps_counted = step_state.current_step
                    batched_tokens += steps_counted - 1

                # Because of chunked prefill, one sequence group can take
                # several prompt runs. Request-level data is therefore only
                # logged for finished requests, which happens once.
                if not seq_group.is_finished():
                    continue

                e2e_latencies.append(timestamp -
                                     seq_group.metrics.arrival_time)
                prompt_lens.append(len(seq_group.prompt_token_ids))
                generation_lens.extend([
                    seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()
                ])
                if seq_group.sampling_params is not None:
                    n_values.append(seq_group.sampling_params.n)
                finish_reasons.extend([
                    SequenceStatus.get_finished_reason(seq.status)
                    for seq in seq_group.get_finished_seqs()
                ])

            # num_batched_tokens covers prompt tokens plus decode tokens of
            # one iteration, so
            #   generation tokens = batched tokens - prompt tokens
            #                       + tokens generated by finished prefills.
            generation_tokens_iter = (batched_tokens - prompt_tokens_iter +
                                      prefill_generation_tokens)

        # Spec decode, if enabled, emits specialized metrics from the worker
        # in the first sampler output.
        spec_metrics = None
        if model_output:
            spec_metrics = model_output[0].spec_decode_worker_metrics

        return Stats(
            now=timestamp,
            # System stats
            #   Scheduler State
            num_running_sys=running_total,
            num_swapped_sys=swapped_total,
            num_waiting_sys=waiting_total,
            #   KV Cache Usage
            gpu_cache_usage_sys=gpu_usage,
            cpu_cache_usage_sys=cpu_usage,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=prompt_tokens_iter,
            num_generation_tokens_iter=generation_tokens_iter,
            time_to_first_tokens_iter=ttft_samples,
            time_per_output_tokens_iter=tpot_samples,
            spec_decode_metrics=spec_metrics,
            num_preemption_iter=preemptions_iter,

            # Request stats
            #   Latency
            time_e2e_requests=e2e_latencies,
            #   Metadata
            num_prompt_tokens_requests=prompt_lens,
            num_generation_tokens_requests=generation_lens,
            n_requests=n_values,
            finished_reason_requests=finish_reasons,
            max_lora=max_loras,
            waiting_lora_adapters=waiting_adapters,
            running_lora_adapters=running_adapters)
1806

1807
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to ``self.model_executor.add_lora``.

        Returns:
            The value returned by the executor.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1809
1810

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to ``self.model_executor.remove_lora``.

        Returns:
            The value returned by the executor.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1812

1813
    def list_loras(self) -> Set[int]:
        """Return whatever ``self.model_executor.list_loras()`` reports."""
        executor = self.model_executor
        return executor.list_loras()
1815

1816
1817
1818
    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to ``self.model_executor.pin_lora``.

        Returns:
            The value returned by the executor.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward the request to ``self.model_executor.add_prompt_adapter``.

        Returns:
            The value returned by the executor.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward the id to ``self.model_executor.remove_prompt_adapter``.

        Returns:
            The value returned by the executor.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return what ``self.model_executor.list_prompt_adapters()`` reports."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1829
    def check_health(self) -> None:
        """Run the health checks of the tokenizer (if set) and the executor.

        The tokenizer check is skipped when ``self.tokenizer`` is falsy; the
        executor check always runs afterwards.
        """
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        self.model_executor.check_health()
1833

1834
    def start_profile(self) -> None:
        """Start profiling through the model executor.

        A plain ``GPUExecutor`` is asked directly; any other executor gets
        the request via ``_run_workers("start_profile")``.
        """
        executor = self.model_executor
        # The exact type is compared on purpose (not isinstance): classes
        # inheriting from GPUExecutor (MultiprocessingGPUExecutor) must take
        # the _run_workers path.
        if type(executor) == GPUExecutor:  # noqa: E721
            executor.start_profile()
            return
        executor._run_workers("start_profile")
1841
1842

    def stop_profile(self) -> None:
        """Stop profiling through the model executor.

        A plain ``GPUExecutor`` is asked directly; any other executor gets
        the request via ``_run_workers("stop_profile")``.
        """
        executor = self.model_executor
        # The exact type is compared on purpose (not isinstance): classes
        # inheriting from GPUExecutor (MultiprocessingGPUExecutor) must take
        # the _run_workers path.
        if type(executor) == GPUExecutor:  # noqa: E721
            executor.stop_profile()
            return
        executor._run_workers("stop_profile")
1849

1850
1851
1852
    def is_tracing_enabled(self) -> bool:
        """Report whether a tracer is set, i.e. ``self.tracer`` is not None."""
        tracer = self.tracer
        return tracer is not None

1853
1854
1855
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for every finished scheduled sequence group.

        Args:
            scheduler_outputs: Its ``scheduled_seq_groups`` are inspected.
            finished_before: Optional indices into ``scheduled_seq_groups``
                that are skipped, to avoid tracing a group twice when async
                output processing is used.
        """
        if self.tracer is None:
            return

        for position, scheduled in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            already_traced = (bool(finished_before)
                              and position in finished_before)
            if already_traced:
                continue

            group = scheduled.seq_group
            if group.is_finished():
                self.create_trace_span(group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit one ``llm_request`` span describing ``seq_group``.

        Nothing happens when no tracer is set or when the group carries no
        sampling params. The span starts at the group's arrival time and its
        parent context is extracted from ``seq_group.trace_headers``.

        Args:
            seq_group: Sequence group whose metrics and sampling params are
                written onto the span as attributes.
        """
        if self.tracer is None:
            return
        if seq_group.sampling_params is None:
            return

        # arrival_time is scaled by 1e9 for the span start time (presumably
        # seconds to nanoseconds — confirm against the tracer API).
        arrival_ns = int(seq_group.metrics.arrival_time * 1e9)
        parent_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=parent_context,
                start_time=arrival_ns) as span:
            timing = seq_group.metrics
            first_token_latency = (timing.first_token_time -
                                   timing.arrival_time)
            total_latency = timing.finished_time - timing.arrival_time

            params = seq_group.sampling_params
            record = span.set_attribute

            # attribute names are based on
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
            record(SpanAttributes.LLM_RESPONSE_MODEL, self.model_config.model)
            record(SpanAttributes.LLM_REQUEST_ID, seq_group.request_id)
            record(SpanAttributes.LLM_REQUEST_TEMPERATURE, params.temperature)
            record(SpanAttributes.LLM_REQUEST_TOP_P, params.top_p)
            record(SpanAttributes.LLM_REQUEST_MAX_TOKENS, params.max_tokens)
            record(SpanAttributes.LLM_REQUEST_N, params.n)
            record(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
                   seq_group.num_seqs())
            record(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                   len(seq_group.prompt_token_ids))
            record(
                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))
            record(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                   timing.time_in_queue)
            record(SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN,
                   first_token_latency)
            record(SpanAttributes.LLM_LATENCY_E2E, total_latency)

            # The remaining timings are optional and only written when set.
            if timing.scheduler_time is not None:
                record(SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
                       timing.scheduler_time)
            if timing.model_forward_time is not None:
                # NOTE(review): the division by 1000 looks like a
                # milliseconds-to-seconds conversion — confirm the unit.
                record(SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
                       timing.model_forward_time / 1000.0)
            if timing.model_execute_time is not None:
                record(SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                       timing.model_execute_time)
1925
1926

    def is_encoder_decoder_model(self):
        """Return ``self.input_preprocessor.is_encoder_decoder_model()``."""
        preprocessor = self.input_preprocessor
        return preprocessor.is_encoder_decoder_model()
1928

1929
1930
    def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
                                                   EncoderDecoderInputs]):
1931
1932
1933
1934
1935
        if self.model_config.is_multimodal_model:
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            prompt_ids = inputs.get("prompt_token_ids")
        elif self.is_encoder_decoder_model():
1936
1937
1938
1939
1940
            prompt_ids = inputs.get("encoder_prompt_token_ids")
        else:
            prompt_ids = inputs.get("prompt_token_ids")

        if prompt_ids is None or len(prompt_ids) == 0:
1941
            raise ValueError("Prompt cannot be empty")
1942

1943
        if self.model_config.is_multimodal_model:
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
            max_prompt_len = self.model_config.max_model_len

            if len(prompt_ids) > max_prompt_len:
                raise ValueError(
                    f"The prompt (total length {len(prompt_ids)}) is too long "
                    f"to fit into the model (context length {max_prompt_len}). "
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
1954
1955
1956
1957

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the guided_decoding,
        logits_bias, and allowed_token_ids fields in sampling_params. Deletes
        those fields and adds the constructed logits processors to the
        logits_processors field. Returns the modified sampling params."""

        logits_processors = []
        if (guided_decoding := sampling_params.guided_decoding) is not None:

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding, tokenizer=tokenizer)
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
            tokenizer = self.get_tokenizer(lora_request=lora_request)

            processors = get_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if logits_processors:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = logits_processors
            else:
                sampling_params.logits_processors.extend(logits_processors)

        return sampling_params