llm_engine.py 82 KB
Newer Older
1
import functools
Antoni Baum's avatar
Antoni Baum committed
2
import time
3
from collections import deque
4
from contextlib import contextmanager
5
from dataclasses import dataclass, field
6
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
7
                    Mapping, Optional)
8
from typing import Sequence as GenericSequence
9
from typing import Set, Tuple, Type, Union
10

11
import torch
12
from typing_extensions import TypeVar, assert_never
13

14
import vllm.envs as envs
15
16
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
17
                         ObservabilityConfig, ParallelConfig,
18
                         PromptAdapterConfig, SchedulerConfig,
19
                         SpeculativeConfig)
20
21
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.engine.arg_utils import EngineArgs
23
from vllm.engine.metrics_types import StatLoggerBase, Stats
24
25
26
27
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
28
from vllm.executor.executor_base import ExecutorBase
29
from vllm.executor.ray_utils import initialize_ray_cluster
30
31
32
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                         InputRegistry, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
33
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.logger import init_logger
35
from vllm.lora.request import LoRARequest
36
from vllm.model_executor.layers.sampler import SamplerOutput
37
from vllm.multimodal import MultiModalDataDict
38
39
40
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
41
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.sampling_params import SamplingParams
43
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
44
45
                           Sequence, SequenceGroup, SequenceGroupMetadata,
                           SequenceStatus)
46
47
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
48
from vllm.transformers_utils.config import try_get_generation_config
49
from vllm.transformers_utils.detokenizer import Detokenizer
50
from vllm.transformers_utils.tokenizer import AnyTokenizer
51
from vllm.transformers_utils.tokenizer_group import (
52
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
53
54
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
55
from vllm.utils import Counter, Device
56
from vllm.version import __version__ as VLLM_VERSION
57
58

logger = init_logger(__name__)

# Interval (seconds, per the name) passed as ``local_interval`` to the
# default logging and Prometheus stat loggers built in ``LLMEngine.__init__``.
_LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Load the model's generation config and return it as a dict.

    Args:
        model_config: Supplies the model name/path, ``trust_remote_code`` and
            ``revision`` used to look up the generation config.

    Returns:
        The result of ``config.to_diff_dict()`` on the loaded generation
        config, or an empty dict when ``try_get_generation_config`` returns
        ``None``.
    """
    config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if config is None:
        return {}

    return config.to_diff_dict()
# Tokenizer-group type requested from / returned by
# ``LLMEngine.get_tokenizer_group``.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
# Request-output type checked by ``LLMEngine.validate_output(s)``.
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

# NOTE(review): presumably (prompt text, prompt token ids, multi-modal data);
# confirm against the code that builds these tuples.
PromptComponents = Tuple[Optional[str], List[int],
                         Optional[MultiModalDataDict]]
# Same layout as PromptComponents, but the token-id list may be None.
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional[MultiModalDataDict]]
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step.

    ``LLMEngine.__init__`` creates one instance per pipeline-parallel rank
    (``cached_scheduler_outputs``).
    """
    # Cached metadata for the scheduled sequence groups, if any.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Cached scheduler outputs, if any.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Whether asynchronous output processing is allowed for the cached step.
    allow_async_output_proc: bool = False
    # Most recent sampler output, if any.
    last_output: Optional[SamplerOutput] = None
@dataclass
class SchedulerContext:
    """Per-virtual-engine state used while processing model outputs.

    ``LLMEngine.__init__`` creates one instance per pipeline-parallel rank and
    binds each to ``_process_model_outputs`` as ``ctx`` (``async_callbacks``).
    """
    # Queue of pending model outputs. NOTE(review): the meaning of the two
    # trailing bools is defined by ``_process_model_outputs``; confirm there.
    output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
                              List[SequenceGroupMetadata], SchedulerOutputs,
                              bool,
                              bool]] = field(default_factory=deque)
    # Request outputs accumulated for the caller.
    request_outputs: List[Union[RequestOutput,
                                EmbeddingRequestOutput]] = field(
                                    default_factory=list)

    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None
class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (a.k.a. the KV cache). This class
    utilizes iteration-level scheduling and efficient memory management to
    maximize the serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine_args`.)

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading the model (e.g. the
            download directory and load format).
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding;
            ``DecodingConfig()`` is used when omitted.
        observability_config (Optional): The configuration related to
            observability (e.g. the OTLP traces endpoint);
            ``ObservabilityConfig()`` is used when omitted.
        prompt_adapter_config (Optional): The configuration related to serving
            prompt adapters.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
        stat_loggers (Optional): Stat loggers to use instead of the default
            logging and Prometheus loggers when ``log_stats`` is set.
        input_registry: Registry used to create the model's input processor.
        step_return_finished_only: If True, only final request outputs are
            returned and intermediate outputs are skipped.
    """

    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Temporarily enable runtime type checks of request outputs.

        Sets ``DO_VALIDATE_OUTPUT`` to ``True`` for the duration of the
        ``with`` block, which makes ``validate_output`` / ``validate_outputs``
        perform their ``isinstance`` checks. The previous value is restored on
        exit, including when the ``with`` body raises.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            # Restore instead of forcing False so an exception cannot leave
            # validation switched on, and nested uses do not switch it off
            # for an enclosing block.
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` unchanged, optionally checking its type first.

        The ``isinstance`` check only runs while ``DO_VALIDATE_OUTPUT`` is
        true (or under static type checking).

        Raises:
            TypeError: If the check runs and ``output`` is not an instance of
                ``output_type``.
        """
        must_check = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if must_check and not isinstance(output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return output

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return ``outputs``, optionally checking the type of every item.

        When validation is off (``DO_VALIDATE_OUTPUT`` false), the caller's
        sequence object is returned as-is without copying. Otherwise a new
        list holding the same items is returned.

        Raises:
            TypeError: On the first item that is not an instance of
                ``output_type`` while validation is on.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # Fast path: no checks, hand back the original object.
            return outputs

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)
        return checked

    tokenizer: Optional[BaseTokenizerGroup]

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        observability_config: Optional[ObservabilityConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        # To improve performance, only final requests outputs may be required.
        # If this is set to True, then no intermediate outputs are returned.
        step_return_finished_only: bool = False,
    ) -> None:
        """Build the engine from already-created configs.

        In order: logs the configuration, loads general plugins, stores the
        configs, creates the tokenizer/detokenizer (unless
        ``skip_tokenizer_init``), creates the input processor and the model
        executor, initializes the KV caches (unless ``embedding_mode``),
        reports usage stats if enabled, and then sets up per-virtual-engine
        scheduler state, the schedulers, stat loggers (when ``log_stats``),
        the tracer (when an OTLP endpoint is configured) and the sequence
        output processor.
        """
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "override_neuron_config=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "pipeline_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
            "num_scheduler_steps=%d, enable_prefix_caching=%s, "
            "use_async_output_proc=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.override_neuron_config,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            observability_config,
            model_config.seed,
            model_config.served_model_name,
            scheduler_config.use_v2_block_manager,
            scheduler_config.num_scheduler_steps,
            cache_config.enable_prefix_caching,
            model_config.use_async_output_proc,
        )
        # TODO(woosuk): Print more configs in debug mode.
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = (observability_config
                                     or ObservabilityConfig())
        self.log_stats = log_stats
        self.step_return_finished_only = step_return_finished_only

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
            observability_config=self.observability_config,
        )

        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    str(cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prompt_adapter":
                    bool(prompt_adapter_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cached-output state, scheduler context and async callback per
        # virtual engine (pipeline-parallel rank).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.async_callbacks = [
            functools.partial(self._process_model_outputs,
                              ctx=self.scheduler_contexts[v_id])
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback = None

        # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                scheduler_config, cache_config, lora_config,
                parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if model_config.use_async_output_proc else None)
            for v_id in range(parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache. When ``cache_config.num_gpu_blocks_override``
        is set, it replaces the profiled GPU block count. The final counts are
        written back to ``cache_config`` and then passed to
        ``model_executor.initialize_cache``.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    @classmethod
    def _get_executor_cls(cls,
                          engine_config: EngineConfig) -> Type[ExecutorBase]:
        """Select the model executor class for ``engine_config``.

        An ``ExecutorBase`` subclass given directly as
        ``distributed_executor_backend`` takes precedence. Otherwise the
        choice depends on the device type and then on the backend name
        (``"ray"`` / ``"mp"``). Ray-based choices call
        ``initialize_ray_cluster`` before returning.

        Raises:
            TypeError: If ``distributed_executor_backend`` is a class that does
                not subclass ``ExecutorBase``.
            ValueError: If the ``"mp"`` backend is requested on an XPU device.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutor
                executor_class = RayTPUExecutor
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutor
                executor_class = TPUExecutor
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.device_config.device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            executor_class = OpenVINOExecutor
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutor
                executor_class = RayXPUExecutor
            elif distributed_executor_backend == "mp":
                # FIXME(kunshang):
                # spawn needs calling `if __name__ == '__main__':``
                # fork is not supported for xpu start new process.
                logger.error(
                    "Both start methods (spawn and fork) have issue "
                    "on XPU if you use mp backend, Please try ray instead.")
                # No executor class exists for this combination; fail with a
                # clear error instead of an UnboundLocalError at the return.
                raise ValueError(
                    "distributed_executor_backend='mp' is not supported on "
                    "XPU; use distributed_executor_backend='ray' instead.")
            else:
                from vllm.executor.xpu_executor import XPUExecutor
                executor_class = XPUExecutor
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingGPUExecutor
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Builds the engine configs with ``engine_args.create_engine_config()``,
        picks the executor class via ``_get_executor_cls``, and passes both to
        the constructor. Statistics logging is enabled unless
        ``engine_args.disable_log_stats`` is set.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        executor_class = cls._get_executor_cls(engine_config)
        # Create the LLM engine.
        engine = cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

        return engine

    def __reduce__(self):
        # Refuse pickling outright: this guarantees the LLMEngine never ends
        # up captured in the closure used to initialize Ray worker actors.
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

    def __del__(self):
        # On garbage collection, shut down the model executor if one exists.
        # __init__ may have raised before model_executor was assigned, hence
        # the getattr lookup with a None default.
        executor = getattr(self, "model_executor", None)
        if not executor:
            return
        executor.shutdown()

    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                   "skip_tokenizer_init is True")

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
        *,
        missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
    ) -> _G:
        """Return the engine's tokenizer group, checked against ``group_type``.

        Args:
            group_type: Expected (base) type of the tokenizer group.
            missing_msg: Error message used when no tokenizer group exists.

        Raises:
            ValueError: If ``self.tokenizer`` is ``None``.
            TypeError: If the tokenizer group is not a ``group_type``.
        """
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
            raise ValueError(missing_msg)
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")

        return tokenizer_group

    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer group's tokenizer for ``lora_request``.

        Raises:
            ValueError: If the tokenizer was not initialized (propagated from
                ``get_tokenizer_group``).
        """
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Create the tokenizer group from the engine's configs.

        LoRA support is requested exactly when ``lora_config`` is truthy.
        """
        use_lora = bool(self.lora_config)
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            enable_lora=use_lora,
        )

    def _verify_args(self) -> None:
        """Cross-check the engine's configs against each other.

        Delegates to the configs' ``verify_with_*`` methods. The LoRA and
        prompt-adapter checks only run when those configs are set.

        NOTE(review): ``__init__`` does not call this method; confirm whether
        anything else relies on it.
        """
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)

    def _get_bos_token_id(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Return the BOS token id of the tokenizer for ``lora_request``.

        When no tokenizer was initialized, a warning is logged and ``None``
        is returned.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
            return tokenizer.bos_token_id

        logger.warning("Using None for BOS token id because tokenizer "
                       "is not initialized")
        return None

    def _get_eos_token_id(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Return the EOS token id of the tokenizer for ``lora_request``.

        Logs a warning and returns ``None`` when no tokenizer was initialized
        (``self.tokenizer`` is ``None``).
        """
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def _get_decoder_start_token_id(self) -> Optional[int]:
        """Obtain the decoder start token id used by an encoder/decoder model.

        Returns ``None`` (with a warning) for non-encoder/decoder models or
        when the model config / HF config is unavailable. If the HF config has
        no ``decoder_start_token_id``, falls back to the BOS token id, which
        is itself ``None`` when the tokenizer is not initialized.
        """
        if not self.is_encoder_decoder_model():
            logger.warning("Using None for decoder start token id because "
                           "this is not an encoder/decoder model.")
            return None

        if (self.model_config is None or self.model_config.hf_config is None):
            logger.warning("Using None for decoder start token id because "
                           "model config is not available.")
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
            logger.warning("Falling back on <BOS> for decoder start token id "
                           "because decoder start token id is not available.")
            dec_start_token_id = self._get_bos_token_id()

        return dec_start_token_id

    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> None:
        """Wrap already-processed inputs in a sequence group and schedule it.

        Validates the inputs and builds the decoder ``Sequence``. An encoder
        ``Sequence`` is also built when ``processed_inputs`` contains
        ``encoder_prompt_token_ids``. Both share one sequence id. The
        sequences are wrapped in a sequence group chosen by the type of
        ``params``, and the group is added to the scheduler that currently
        has the fewest unfinished sequence groups.

        Raises:
            ValueError: If ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``.
        """
        self._validate_model_inputs(processed_inputs)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self._get_eos_token_id(lora_request)

        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        encoder_seq = None
        if 'encoder_prompt_token_ids' in processed_inputs:
            encoder_seq = Sequence(seq_id,
                                   processed_inputs,
                                   block_size,
                                   eos_token_id,
                                   lora_request,
                                   prompt_adapter_request,
                                   from_decoder_prompt=False)

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with the fewest unfinished
        # sequence groups; min() keeps the first scheduler among ties.
        min_cost_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        min_cost_scheduler.add_seq_group(seq_group)

    def stop_remote_worker_execution_loop(self) -> None:
        """Ask the model executor to stop its remote worker execution loop."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
712

713
    # Type alias for a (str, List[int]) pair, presumably
    # (prompt text, prompt token ids).
    # NOTE(review): not referenced in the visible part of this file; the
    # PromptComponents / DecoderPromptComponents aliases are used instead.
    # Confirm there are no other users before removing it.
    _LLMInputComponentsType = Tuple[str, List[int]]
714
715
716

    def _prepare_decoder_input_ids_for_generation(
        self,
717
        decoder_input_ids: Optional[List[int]],
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
    ) -> List[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """

739
        decoder_start_token_id = self._get_decoder_start_token_id()
740
741
742
743
744
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
745
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
746
747
748
749
750
751
752
753
754
755

        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _tokenize_prompt(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        '''
        Wrapper around application of the model's tokenizer.

        Arguments:

        * prompt: text to tokenize
        * request_id: forwarded to the tokenizer group's ``encode``
        * lora_request: forwarded to the tokenizer group's ``encode``

        Returns:

        * prompt token ids
        '''
        # ``missing_msg`` is the error text used when no tokenizer group is
        # available (per the message, when skip_tokenizer_init is True).
        tokenizer = self.get_tokenizer_group(
            missing_msg="prompts must be None if skip_tokenizer_init is True")

        return tokenizer.encode(request_id=request_id,
                                prompt=prompt,
                                lora_request=lora_request)
779

780
    def _extract_prompt_components(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        '''
        Extract the components of any single encoder or decoder input prompt.

        Arguments:

        * inputs: single encoder or decoder input prompt; either a plain
          string or a dict carrying ``prompt_token_ids`` or ``prompt``
          (optionally with ``multi_modal_data``)
        * request_id
        * lora_request: this is only valid for decoder prompts

        Returns:

        * prompt: the text prompt, or ``None`` when token ids were given
        * prompt_token_ids
        * multi_modal_data: always ``None`` for plain string prompts
        '''

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = self._tokenize_prompt(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif isinstance(inputs, dict):
            if "prompt_token_ids" in inputs:
                # Pre-tokenized input: given token ids take precedence and
                # any "prompt" key is ignored.
                prompt = None
                prompt_token_ids = inputs["prompt_token_ids"]
            else:
                # NOTE: This extra assignment is required to pass mypy
                prompt = parsed_prompt = inputs["prompt"]
                prompt_token_ids = self._tokenize_prompt(
                    parsed_prompt,
                    request_id=request_id,
                    lora_request=lora_request,
                )

            multi_modal_data = inputs.get("multi_modal_data")
        else:
            assert_never(inputs)

        return prompt, prompt_token_ids, multi_modal_data
828

829
830
831
832
833
834
835
836
837
    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        """Reserve room for a prompt adapter's virtual tokens.

        Arguments:

        * prompt_token_ids: the prompt's token ids
        * prompt_adapter_request: when truthy, its
          ``prompt_adapter_num_virtual_tokens`` placeholder ids (``0``) are
          prepended to the prompt

        Returns:

        * the token ids, with placeholders prepended when an adapter is
          requested (as a new list), otherwise the input list unchanged
        """
        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)

        return prompt_token_ids
840

841
    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
        '''
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
        other models may have different or more 
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
        '''

        bos_token_id = self._get_bos_token_id()
        assert bos_token_id is not None
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
        return [bos_token_id]

    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
    ) -> EncoderDecoderLLMInputs:
        """Combine encoder and decoder prompt components into model inputs.

        The decoder token ids are normalised by
        ``_prepare_decoder_input_ids_for_generation`` before being stored.

        Raises:
            ValueError: if either side carries multi-modal data, which
                encoder/decoder models do not support yet.
        """
        enc_text, enc_token_ids, enc_mm = encoder_comps
        dec_text, dec_token_ids, dec_mm = decoder_comps

        has_mm_data = not (enc_mm is None and dec_mm is None)
        if has_mm_data:
            raise ValueError("Multi-modal encoder-decoder models are "
                             "not supported yet")

        prepared_dec_ids = self._prepare_decoder_input_ids_for_generation(
            dec_token_ids)

        return EncoderDecoderLLMInputs(
            prompt_token_ids=prepared_dec_ids,
            prompt=dec_text,
            encoder_prompt_token_ids=enc_token_ids,
            encoder_prompt=enc_text,
        )
898
899
900
901

    def _process_encoder_decoder_prompt(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        '''
        For encoder/decoder models only:
        Process an input prompt into an
        :class:`EncoderDecoderLLMInputs` instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:
        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for Explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        No LoRA request is forwarded when tokenizing either sub-prompt.

        Arguments:

        * inputs: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderLLMInputs` instance
        '''

        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = self._extract_prompt_components(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                # No decoder prompt: the default decoder prompt is filled in
                # when the inputs are built.
                decoder_comps = None, None, None
            else:
                decoder_comps = self._extract_prompt_components(
                    decoder_input,
                    request_id=request_id,
                )
        else:
            # Singleton prompt: it is used as the encoder prompt only.
            encoder_comps = self._extract_prompt_components(
                inputs,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        """Turn extracted prompt components into decoder-only model inputs.

        Prompt-adapter virtual tokens (if any) are applied to the token ids
        before the :class:`LLMInputs` is built.
        """
        text, token_ids, mm_data = prompt_comps

        adapted_token_ids = self._apply_prompt_adapter(
            token_ids, prompt_adapter_request=prompt_adapter_request)

        return LLMInputs(
            prompt_token_ids=adapted_token_ids,
            prompt=text,
            multi_modal_data=mm_data,
        )
976
977

    def _process_decoder_only_prompt(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        '''
        For decoder-only models:
        Process an input prompt into an :class:`LLMInputs` instance.

        Arguments:

        * inputs: input prompt
        * request_id
        * lora_request: forwarded to tokenization of text prompts
        * prompt_adapter_request: if set, its virtual-token placeholders are
          prepended to the prompt token ids

        Returns:

        * :class:`LLMInputs` instance
        '''

        prompt_comps = self._extract_prompt_components(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )
1010
1011
1012
1013

    def process_model_inputs(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Convert a user-facing prompt into inputs for the model.

        Encoder/decoder models produce :class:`EncoderDecoderLLMInputs`;
        decoder-only models produce :class:`LLMInputs`. Either result is
        then passed through ``self.input_processor``.

        Args:
            inputs: The prompt; see :class:`~vllm.inputs.PromptInputs`.
            request_id: The unique ID of the request (used for tokenization).
            lora_request: Optional LoRA adapter; only forwarded for
                decoder-only models.
            prompt_adapter_request: Optional prompt adapter; only forwarded
                for decoder-only models.

        Raises:
            ValueError: If an explicit encoder/decoder prompt is passed to a
                decoder-only model.
        """
        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            model_inputs = self._process_encoder_decoder_prompt(
                inputs,
                request_id=request_id,
            )
        else:
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            model_inputs = self._process_decoder_only_prompt(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        return self.input_processor(model_inputs)
1040

1041
1042
1043
    def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, we use
                the current wall-clock time (``time.time()``).
            lora_request: Optional LoRA adapter for this request. Requires
                LoRA to be enabled on the engine.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Optional prompt adapter for this request.

        Raises:
            ValueError: If ``lora_request`` is given while LoRA is not
                enabled, or if the inputs/params are invalid for the model.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create a :class:`~vllm.Sequence` object (plus an encoder-side
              :class:`~vllm.Sequence` for encoder/decoder inputs).
            - Create a :class:`~vllm.SequenceGroup` object
              from the :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = self.process_model_inputs(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )
1115
1116
1117
1118
1119
1120

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams.

        The caller's ``sampling_params`` is left untouched. A clone is
        updated with ``self.generation_config_fields`` and the sequence's EOS
        token id, and that clone is stored on the group.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model config's ``max_logprobs``.
        """
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()

        sampling_params.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        # Create the sequence group.
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq)

        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with PoolingParams.

        A clone of ``pooling_params`` is stored on the group, so the caller's
        object is not shared with the engine.
        """
        # Defensive copy of PoolingParams, which are used by the pooler
        pooling_params = pooling_params.clone()
        # Create the sequence group.
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq)
        return seq_group
1179

Antoni Baum's avatar
Antoni Baum committed
1180
1181
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        The ID(s) are passed unchanged to every scheduler's
        ``abort_seq_group``.

        Args:
            request_id: The ID(s) of the request to abort.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for scheduler in self.scheduler:
            scheduler.abort_seq_group(request_id)
1199

1200
1201
1202
1203
    def get_model_config(self) -> ModelConfig:
        """Return the engine's model configuration object."""
        config = self.model_config
        return config

1204
1205
1206
1207
    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's parallelism configuration object."""
        config = self.parallel_config
        return config

1208
1209
1210
1211
    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's decoding configuration object."""
        config = self.decoding_config
        return config

1212
1213
1214
1215
1216
1217
1218
1219
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's scheduler configuration object."""
        config = self.scheduler_config
        return config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Gets the LoRA configuration.

        May be ``None``/falsy when LoRA is not enabled: ``add_request``
        treats a falsy ``lora_config`` as "LoRA disabled".
        """
        return self.lora_config

1220
    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests.

        Sums the unfinished sequence-group counts of all schedulers (one per
        virtual engine).
        """
        return sum(scheduler.get_num_unfinished_seq_groups()
                   for scheduler in self.scheduler)
1224

1225
    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests.

        True when any scheduler reports unfinished sequences; stops at the
        first one that does.
        """
        return any(scheduler.has_unfinished_seqs()
                   for scheduler in self.scheduler)

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Report whether the scheduler at index ``virtual_engine`` still
        has unfinished sequences.
        """
        engine_scheduler = self.scheduler[virtual_engine]
        return engine_scheduler.has_unfinished_seqs()
1236

1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
    def _process_sequence_group_outputs(
        self,
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        """Store pooled embeddings on ``seq_group`` and finish its sequences.

        Only ``outputs[0].embeddings`` is used. Every sequence in the group
        is then marked ``SequenceStatus.FINISHED_STOPPED``.
        """
        first_output = outputs[0]
        seq_group.embeddings = first_output.embeddings

        for member in seq_group.get_seqs():
            member.status = SequenceStatus.FINISHED_STOPPED

1249
    def _process_model_outputs(self, ctx: SchedulerContext) -> None:
        """Apply the model output to the sequences in the scheduled seq groups.

        Pops the oldest entry from ``ctx.output_queue`` (does nothing if the
        queue is empty). Each entry is a tuple ``(outputs,
        seq_group_metadata_list, scheduler_outputs, is_async, is_last_step)``:

        * ``is_async``: indicates whether this postprocessor runs in
          parallel with the GPU forward pass and is processing tokens from
          the previous step. If true, computed-token counts are not updated
          here, since that is already done externally (before the next
          schedule() call).
        * ``is_last_step``: used with multi-step execution to indicate the
          last step (of each multi-step group). For earlier steps only the
          requests that finished in this step produce outputs.

        Nothing is returned. Request outputs are appended to
        ``ctx.request_outputs`` and, when
        ``self.process_request_outputs_callback`` is set, passed to it.
        For async processing, stats logging and tracing also happen here.

        Args:
            ctx: The scheduler context whose output queue is processed and
                whose ``request_outputs`` list collects the results.
        """
        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Pop the oldest pending output to post-process.
        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
         is_last_step) = ctx.output_queue.popleft()
        assert outputs is not None

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        # Reorganize multi-step outputs from [step][sequence group] to
        # [sequence group][step]. A single-step output stays indexed as
        # [step][sequence group].
        if len(outputs) > 1:
            outputs_by_sequence_group = create_output_by_sequence_group(
                outputs, num_seq_groups=len(seq_group_metadata_list))
        else:
            outputs_by_sequence_group = outputs

        finished_before: List[int] = []
        finished_now: List[int] = []
        for i, seq_group_meta in enumerate(seq_group_metadata_list):
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            if len(outputs) > 1:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            if not is_async:
                seq_group.update_num_computed_tokens(
                    scheduled_seq_group.token_chunk_size)

            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            if self.model_config.embedding_mode:
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(seq_group)
            ctx.request_outputs.append(request_output)

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step, do not create outputs each iteration
        if not is_last_step:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
            return

        # Create the outputs
        # Note: scheduled_seq_groups and seq_group_metadata_list
        # must match with the indices
        # A set keeps the membership test O(1); finished_before itself is
        # still passed unchanged to do_log_stats below.
        already_handled = set(finished_before).union(finished_now)
        for i, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            if i in already_handled:
                continue  # Avoids double processing

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if (not self.step_return_finished_only
                    or seq_group.is_finished()):
                request_output = RequestOutputFactory.create(seq_group)
                ctx.request_outputs.append(request_output)

        for seq_group in scheduler_outputs.ignored_seq_groups:
            request_output = RequestOutputFactory.create(seq_group)
            ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before)

            # Tracing
            self.do_tracing(scheduler_outputs)

        return None

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Append one step's sampled tokens to the scheduled sequences.

        The output processor normally does this. It must happen here instead
        when the worker runs the next forward pass asynchronously.

        Finished sequence groups are skipped. Every other group has its
        computed-token count advanced by the metadata's
        ``token_chunk_size``. If it sampled, its single sampled token (with
        logprobs) is appended to its only sequence.
        """
        step_items = zip(seq_group_metadata_list, output, scheduled_seq_groups)
        for metadata, group_output, scheduled in step_items:
            seq_group = scheduled.seq_group
            if seq_group.is_finished():
                continue

            seq_group.update_num_computed_tokens(metadata.token_chunk_size)

            if not metadata.do_sample:
                continue

            assert len(group_output.samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1 and no "
                "sampling_params.best_of > 1)")
            sample = group_output.samples[0]

            assert len(seq_group.seqs) == 1
            only_seq = seq_group.seqs[0]
            only_seq.append_token_id(sample.output_token, sample.logprobs)
1424

1425
    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        With multi-step scheduling, the scheduler is only called again once
        the cached batch has no remaining steps; while steps remain, the
        method returns ``ctx.request_outputs`` right after executing the
        model.

        Raises:
            NotImplementedError: if ``pipeline_parallel_size > 1``; pipeline
                parallelism is only supported through ``AsyncLLMEngine``.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>             print(request_output)
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # For llm_engine, there is no pipeline parallel support, so the engine
        # used is always 0.
        virtual_engine = 0

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

            # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, output)
        else:
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            output = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps.
            if self.scheduler_config.is_multi_step:
                # Reset the slot of the engine we actually used (same index as
                # the read above) rather than a hard-coded 0.
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # Add results to the output_queue
            is_async = allow_async_output_proc
            is_last_step = True
            ctx.output_queue.append(
                (output, seq_group_metadata_list, scheduler_outputs, is_async,
                 is_last_step))

            if output and allow_async_output_proc:
                assert len(output) == 1, (
                    "Async postprocessor expects only a single output set")

                self._advance_to_next_step(
                    output[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Check if need to run the usual non-async path
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, output)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return whether the current multi-step batch still has steps left.

        Always ``False`` when multi-step scheduling is disabled or when
        ``seq_group_metadata_list`` is ``None`` or empty. Otherwise the
        answer is ``remaining_steps > 0`` of the (shared) step counter.

        Raises:
            AssertionError: if the sequence groups disagree on how many
                steps remain.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        # Generator (not a list) so the scan stops at the first mismatch.
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError(("All running sequence groups should "
                                  "have the same remaining steps."))

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store this schedule so later multi-step iterations can reuse it.

        ``step`` reads these fields back from
        ``self.cached_scheduler_outputs[virtual_engine]`` instead of calling
        the scheduler while the batch still has remaining steps. Any
        ``last_output`` left over from an earlier batch is cleared.
        """
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Stash the final sampler output of this step for pipeline parallel.

        Only acts when more than one pipeline stage is configured and
        ``output`` is non-empty with a non-``None`` first entry. The last
        entry of ``output`` is then stored as ``last_output`` on the cached
        scheduler state of ``virtual_engine``, so its CPU-side sampled token
        ids can be handed to the next iteration.
        """
        if self.parallel_config.pipeline_parallel_size <= 1:
            return
        if len(output) == 0 or output[0] is None:
            return

        final_output = output[-1]
        # Only the CPU copy of the sampled ids is expected to be populated.
        assert final_output is not None
        assert final_output.sampled_token_ids_cpu is not None
        assert final_output.sampled_token_ids is None
        assert final_output.sampled_token_probs is None
        cached_state = self.cached_scheduler_outputs[virtual_engine]
        cached_state.last_output = final_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1675
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` under ``logger_name`` so it receives stats.

        Raises:
            RuntimeError: if stat logging is disabled (``log_stats`` is
                falsy).
            KeyError: if a logger with this name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        self.stat_loggers[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger registered under ``logger_name``.

        Raises:
            RuntimeError: if stat logging is disabled (``log_stats`` is
                falsy).
            KeyError: if no logger with this name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None) -> None:
        """Build a ``Stats`` snapshot and hand it to every registered logger.

        Does nothing when stat logging is disabled. All arguments are
        optional and are forwarded to ``_get_stats``. When called with no
        arguments (e.g. a forced log while no requests are active), the
        iteration and request stats come out empty.
        """
        if not self.log_stats:
            return
        stats = self._get_stats(scheduler_outputs, model_output,
                                finished_before)
        for logger in self.stat_loggers.values():
            logger.log(stats)

    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional indices into
                ``scheduler_outputs.scheduled_seq_groups`` of groups already
                finished by the async output processor; they are skipped so
                they are not counted twice.

        Returns:
            A ``Stats`` snapshot of system, iteration and request metrics.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        # Require a non-zero total (as the CPU branch below does) so a
        # zero-sized cache cannot raise ZeroDivisionError.
        if num_total_gpu is not None and num_total_gpu > 0:
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu is not None and num_total_cpu > 0:
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            # Token counts are integral; an int accumulator (not 0.) keeps
            # num_generation_tokens_iter an int as well.
            num_generation_tokens_from_prefill_groups = 0

            # Build the skip-set once so the per-group check is O(1).
            skip_idxs = set(finished_before) if finished_before else set()

            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in skip_idxs:
                    actual_num_batched_tokens -= 1
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    if seq_group.sampling_params is not None:
                        best_of_requests.append(
                            seq_group.sampling_params.best_of)
                        n_requests.append(seq_group.sampling_params.n)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
        )

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the model executor; return its result."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Unload a LoRA adapter via the model executor; return its result."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids the model executor reports as loaded."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the model executor; return its result."""
        return self.model_executor.pin_lora(lora_id)

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Ask the model executor to load a prompt adapter.

        The executor's boolean result is returned unchanged.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Ask the model executor to drop the prompt adapter with this id."""
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the prompt adapter ids the model executor reports."""
        return self.model_executor.list_prompt_adapters()

    def check_health(self) -> None:
        """Run the tokenizer's (if configured) then the executor's health check.

        Any exception raised by either check propagates to the caller.
        """
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()

    def start_profile(self) -> None:
        """Forward a start-profiling request to the model executor."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the model executor."""
        self.model_executor.stop_profile()

    def is_tracing_enabled(self) -> bool:
        """Report whether a tracer has been configured on this engine."""
        tracer_missing = self.tracer is None
        return not tracer_missing

    def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
        """Create a trace span for each scheduled group that has finished.

        A no-op when no tracer is configured.
        """
        if self.tracer is not None:
            finished_groups = (
                scheduled.seq_group
                for scheduled in scheduler_outputs.scheduled_seq_groups
                if scheduled.seq_group.is_finished())
            for seq_group in finished_groups:
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit an ``llm_request`` tracing span describing ``seq_group``.

        The span starts at the request's arrival time and is parented to the
        trace context extracted from the request's trace headers. It carries
        model, sampling-parameter, usage and latency attributes; the optional
        scheduler/forward/execute latencies are set only when recorded.
        Nothing is emitted when tracing is disabled or the group has no
        sampling params.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            ttft = metrics.first_token_time - metrics.arrival_time
            e2e_time = metrics.finished_time - metrics.arrival_time
            # attribute names are based on
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
            seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
                                   seq_group.sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
                                   seq_group.sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                   seq_group.sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
                                   seq_group.sampling_params.best_of)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                   seq_group.sampling_params.n)
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                                   metrics.time_in_queue)
            seq_span.set_attribute(
                SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # NOTE(review): only the forward time is divided by 1000;
                # presumably it is recorded in milliseconds while the other
                # latencies are in seconds -- confirm where it is set.
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)

    def is_encoder_decoder_model(self):
        """Return ``model_config.is_encoder_decoder_model`` for this engine."""
        return self.model_config.is_encoder_decoder_model

    def is_embedding_model(self):
        """Return ``model_config.is_embedding_model`` for this engine."""
        return self.model_config.is_embedding_model

    def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                   EncoderDecoderLLMInputs]):
2002
2003
2004
2005
2006
2007
        if self.is_encoder_decoder_model():
            prompt_ids = inputs.get("encoder_prompt_token_ids")
        else:
            prompt_ids = inputs.get("prompt_token_ids")

        if prompt_ids is None or len(prompt_ids) == 0:
2008
            raise ValueError("Prompt cannot be empty")
2009

2010
        if self.model_config.is_multimodal_model:
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
            max_prompt_len = self.model_config.max_model_len

            if len(prompt_ids) > max_prompt_len:
                raise ValueError(
                    f"The prompt (total length {len(prompt_ids)}) is too long "
                    f"to fit into the model (context length {max_prompt_len}). "
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
2021
2022
2023
2024

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens