llm_engine.py 51.1 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from contextlib import contextmanager
3
4
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
                    Mapping, Optional)
5
from typing import Sequence as GenericSequence
6
from typing import Set, Type, TypeVar, Union
7

8
from transformers import PreTrainedTokenizer
9

10
import vllm.envs as envs
11
12
13
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
14
                         PromptAdapterConfig, SchedulerConfig,
15
                         SpeculativeConfig)
16
17
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.engine.arg_utils import EngineArgs
19
20
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
                                 StatLoggerBase, Stats)
21
22
23
24
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
25
from vllm.executor.executor_base import ExecutorBase
26
from vllm.executor.ray_utils import initialize_ray_cluster
27
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
28
from vllm.logger import init_logger
29
from vllm.lora.request import LoRARequest
30
31
32
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
33
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.sampling_params import SamplingParams
35
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
36
37
                           PoolerOutput, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupMetadata,
38
                           SequenceStatus)
39
40
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
41
from vllm.transformers_utils.config import try_get_generation_config
42
from vllm.transformers_utils.detokenizer import Detokenizer
43
44
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
yhu422's avatar
yhu422 committed
45
46
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
47
from vllm.utils import Counter
48
from vllm.version import __version__ as VLLM_VERSION
49
50

logger = init_logger(__name__)

# Passed as ``local_interval`` to the default LoggingStatLogger and
# PrometheusStatLogger that ``LLMEngine.__init__`` creates.
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a plain dict.

    The config is looked up using the model's ``model``, ``revision`` and
    ``trust_remote_code`` settings. If no config is found, an empty dict is
    returned; otherwise the result of its ``to_diff_dict()`` is returned.
    """
    found = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    return {} if found is None else found.to_diff_dict()


# Constrained type variable for ``LLMEngine.validate_output(s)``: the
# expected request-output type is either RequestOutput or
# EmbeddingRequestOutput.
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)


class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
71
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine_args`)

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading model weights
            (e.g. the download directory and load format).
        lora_config (Optional): The configuration related to serving
            multi-LoRA.
        multimodal_config (Optional): The configuration related to multimodal
            models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding.
            A default :class:`DecodingConfig` is used when None.
        observability_config (Optional): The configuration related to
            observability (e.g. the OTLP traces endpoint). A default
            :class:`ObservabilityConfig` is used when None.
        prompt_adapter_config (Optional): The configuration related to serving
            prompt adapters.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
        stat_loggers (Optional): Stat loggers keyed by name. When None and
            ``log_stats`` is set, a logging stat logger and a Prometheus stat
            logger are created.
    """

    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Temporarily turn on ``DO_VALIDATE_OUTPUT`` for this class.

        The previous value of the flag is restored on exit, including when
        the ``with`` body raises. An exception inside the block therefore
        cannot leave validation permanently enabled, and a nested use does
        not switch validation off for the enclosing block.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output``, type-checking it when validation is enabled.

        With ``DO_VALIDATE_OUTPUT`` off (and not under static type checking)
        the object is passed through untouched.

        Raises:
            TypeError: if validation is on and ``output`` is not an instance
                of ``output_type``.
        """
        # Under TYPE_CHECKING this branch is dead, so the isinstance below
        # narrows ``output`` for the type checker.
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            return output

        if isinstance(output, output_type):
            return output

        raise TypeError(f"Expected output of type {output_type}, "
                        f"but found type {type(output)}")

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return ``outputs``, type-checking each item when validation is on.

        With validation off, the original sequence object is returned as-is.
        With it on, a new list is built in a single pass.

        Raises:
            TypeError: on the first item that is not an ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            return outputs

        def _checked(item: object) -> _O:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            return item

        return [_checked(item) for item in outputs]

    tokenizer: Optional[BaseTokenizerGroup]

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        observability_config: Optional[ObservabilityConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> None:
        """Set up the tokenizer, model executor, KV cache, schedulers, stat
        loggers, tracer and output processor for this engine.

        Once the model executor has been created, any exception raised during
        the remaining setup shuts the executor down before propagating.
        """
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "pipeline_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
            "enable_prefix_caching=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            observability_config,
            model_config.seed,
            model_config.served_model_name,
            scheduler_config.use_v2_block_manager,
            cache_config.enable_prefix_caching,
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config or ObservabilityConfig(
        )
        self.log_stats = log_stats
        # NOTE(review): ``_verify_args()`` is defined on this class but is not
        # called here; confirm the configs are cross-validated elsewhere.

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.tokenizer = None
            self.detokenizer = None

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_processor = INPUT_REGISTRY.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            multimodal_config=multimodal_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
        )

        init_success = False
        try:
            # Embedding models skip KV-cache initialization.
            if not self.model_config.embedding_mode:
                self._initialize_kv_caches()

            # If usage stat is enabled, collect relevant info.
            if is_usage_stats_enabled():
                from vllm.model_executor.model_loader import (
                    get_architecture_class_name)
                usage_message.report_usage(
                    get_architecture_class_name(model_config),
                    usage_context,
                    extra_kvs={
                        # Common configuration
                        "dtype":
                        str(model_config.dtype),
                        "tensor_parallel_size":
                        parallel_config.tensor_parallel_size,
                        "block_size":
                        cache_config.block_size,
                        "gpu_memory_utilization":
                        cache_config.gpu_memory_utilization,

                        # Quantization
                        "quantization":
                        model_config.quantization,
                        "kv_cache_dtype":
                        str(cache_config.cache_dtype),

                        # Feature flags
                        "enable_lora":
                        bool(lora_config),
                        "enable_prompt_adapter":
                        bool(prompt_adapter_config),
                        "enable_prefix_caching":
                        cache_config.enable_prefix_caching,
                        "enforce_eager":
                        model_config.enforce_eager,
                        "disable_custom_all_reduce":
                        parallel_config.disable_custom_all_reduce,
                    })

            if self.tokenizer:
                # Ping the tokenizer to ensure liveness if it runs in a
                # different process.
                self.tokenizer.ping()

            # Create the scheduler.
            # NOTE: the cache_config here have been updated with the numbers of
            # GPU and CPU blocks, which are profiled in the distributed executor.
            # ``pipeline_parallel_size`` schedulers are created.
            self.scheduler = [
                Scheduler(scheduler_config, cache_config, lora_config,
                          parallel_config.pipeline_parallel_size)
                for _ in range(parallel_config.pipeline_parallel_size)
            ]

            # Metric Logging.
            if self.log_stats:
                if stat_loggers is not None:
                    self.stat_loggers = stat_loggers
                else:
                    self.stat_loggers = {
                        "logging":
                        LoggingStatLogger(
                            local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                        "prometheus":
                        PrometheusStatLogger(
                            local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                            labels=dict(
                                model_name=model_config.served_model_name),
                            max_model_len=self.model_config.max_model_len),
                    }
                    self.stat_loggers["prometheus"].info(
                        "cache_config", self.cache_config)

            self.tracer = None
            if self.observability_config.otlp_traces_endpoint:
                self.tracer = init_tracer(
                    "vllm.llm_engine",
                    self.observability_config.otlp_traces_endpoint)

            # The callback below is called with a single Sequence. It captures
            # the tokenizer group and error message as locals rather than
            # ``self``. That keeps ``self.output_processor`` from forming a
            # reference cycle back to the engine, which would postpone
            # ``__del__`` (and executor shutdown) until the cyclic GC runs.
            tokenizer_group = self.tokenizer
            missing_tokenizer_msg = self.MISSING_TOKENIZER_GROUP_MSG

            def get_tokenizer_for_seq(
                    sequence: Sequence) -> "PreTrainedTokenizer":
                # Same failure mode as ``get_tokenizer_group()``.
                if tokenizer_group is None:
                    raise ValueError(missing_tokenizer_msg)
                return tokenizer_group.get_lora_tokenizer(
                    sequence.lora_request)

            # Create sequence output processor, e.g. for beam search or
            # speculative decoding.
            self.output_processor = (
                SequenceGroupOutputProcessor.create_output_processor(
                    self.scheduler_config,
                    self.detokenizer,
                    self.scheduler,
                    self.seq_counter,
                    get_tokenizer_for_seq,
                    stop_checker=StopChecker(
                        self.scheduler_config.max_model_len,
                        get_tokenizer_for_seq,
                    ),
                ))
            init_success = True
        finally:
            if not init_success:
                # Ensure that model_executor is shut down if LLMEngine init
                # failed
                self.model_executor.shutdown()

    def _initialize_kv_caches(self) -> None:
        """Size and initialize the KV cache in the worker(s).

        The executor reports how many GPU and CPU (swap) blocks are
        available. A configured ``num_gpu_blocks_override`` replaces the GPU
        figure. The final counts are written back into ``self.cache_config``
        and then passed to the executor's ``initialize_cache``.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        override = self.cache_config.num_gpu_blocks_override
        if override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks, override)
            num_gpu_blocks = override

        cache_config = self.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    @classmethod
    def _get_executor_cls(cls,
                          engine_config: EngineConfig) -> Type[ExecutorBase]:
        """Choose the model executor class for ``engine_config``.

        Precedence is as follows:
        1. An ``ExecutorBase`` subclass passed directly as
           ``distributed_executor_backend``.
        2. The device type (neuron, tpu, cpu, openvino, xpu).
        3. The backend name ("ray", "mp").
        4. The single-process ``GPUExecutor``.

        Ray-backed choices initialize the Ray cluster as a side effect.

        Raises:
            TypeError: if a class is supplied that does not subclass
                ``ExecutorBase``.
        """
        parallel_config = engine_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)

        # The caller handed us the executor class itself.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(parallel_config)
            return distributed_executor_backend

        device_type = engine_config.device_config.device_type
        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            return NeuronExecutor
        if device_type == "tpu":
            from vllm.executor.tpu_executor import TPUExecutor
            return TPUExecutor
        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            return CPUExecutor
        if device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            return OpenVINOExecutor
        if device_type == "xpu":
            if distributed_executor_backend != "ray":
                from vllm.executor.xpu_executor import XPUExecutor
                return XPUExecutor
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_xpu_executor import RayXPUExecutor
            return RayXPUExecutor

        # All other devices fall through to the GPU executors, selected by
        # backend name.
        if distributed_executor_backend == "ray":
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            return RayGPUExecutor
        if distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            return MultiprocessingGPUExecutor

        from vllm.executor.gpu_executor import GPUExecutor
        return GPUExecutor

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The engine config is built from ``engine_args`` and the executor
        class is chosen from that config. Statistics logging is enabled
        unless ``engine_args.disable_log_stats`` is set.
        """
        engine_config = engine_args.create_engine_config()
        # Resolve the executor before expanding the config into kwargs.
        executor_class = cls._get_executor_cls(engine_config)
        return cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __reduce__(self):
        """Refuse to be pickled.

        The engine must never end up referenced by the closure used to
        initialize Ray worker actors, so any pickling attempt fails loudly.
        """
        raise RuntimeError("LLMEngine should not be pickled!")

    def __del__(self):
        """Shut down the model executor when the engine is garbage collected."""
        # ``__init__`` may have raised before ``model_executor`` was
        # assigned, so look it up defensively.
        executor = getattr(self, "model_executor", None)
        if executor:
            executor.shutdown()

    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                   "skip_tokenizer_init is True")

    def get_tokenizer_group(
            self,
            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
        """Return the engine's tokenizer group.

        Raises:
            ValueError: with ``fail_msg`` when no tokenizer group exists,
                which happens when the engine was built with
                ``skip_tokenizer_init``.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            return tokenizer_group
        raise ValueError(fail_msg)

    def get_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request`` (base tokenizer if None).

        Raises:
            ValueError: if the engine was created with
                ``skip_tokenizer_init`` (no tokenizer group exists).
        """
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        """Return the tokenizer matching ``sequence``'s LoRA request.

        Raises:
            ValueError: if the engine was created with
                ``skip_tokenizer_init`` (no tokenizer group exists).
        """
        return self.get_tokenizer_group().get_lora_tokenizer(
            sequence.lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
        """Build the tokenizer group from the engine's configs.

        Keyword arguments are derived from the model, scheduler and LoRA
        configs. Any entry in ``tokenizer_init_kwargs`` overrides the derived
        value of the same name.
        """
        model_config = self.model_config
        init_kwargs = {
            "tokenizer_id": model_config.tokenizer,
            "enable_lora": bool(self.lora_config),
            "max_num_seqs": self.scheduler_config.max_num_seqs,
            "max_input_length": None,
            "tokenizer_mode": model_config.tokenizer_mode,
            "trust_remote_code": model_config.trust_remote_code,
            "revision": model_config.tokenizer_revision,
            **tokenizer_init_kwargs,
        }
        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
                                   **init_kwargs)

    def _verify_args(self) -> None:
        """Cross-check the engine's configs against each other.

        Model and cache configs are verified against the parallel config.
        The LoRA config (if any) is verified against the model and scheduler
        configs, and the prompt-adapter config (if any) against the model
        config.
        """
        parallel_config = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel_config)
        self.cache_config.verify_with_parallel_config(parallel_config)

        lora_config = self.lora_config
        if lora_config:
            lora_config.verify_with_model_config(self.model_config)
            lora_config.verify_with_scheduler_config(self.scheduler_config)

        prompt_adapter_config = self.prompt_adapter_config
        if prompt_adapter_config:
            prompt_adapter_config.verify_with_model_config(self.model_config)

    def _get_eos_token_id(
            self, lora_request: Optional[LoRARequest]) -> Optional[int]:
        """Return the EOS token id of the tokenizer for ``lora_request``.

        Returns None, and logs a warning, when the engine has no tokenizer.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
            return tokenizer.eos_token_id

        logger.warning("Using None for EOS token id because tokenizer "
                       "is not initialized")
        return None

    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: LLMInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> None:
        """Wrap already-processed inputs in a sequence group and enqueue it.

        A single Sequence is created with the next id from ``seq_counter``.
        It is grouped according to the kind of ``params``, and the group is
        handed to the scheduler reporting the fewest unfinished sequence
        groups (the first such scheduler on ties).

        Raises:
            ValueError: if ``params`` is neither SamplingParams nor
                PoolingParams.
        """
        seq_id = next(self.seq_counter)
        seq = Sequence(seq_id, processed_inputs, self.cache_config.block_size,
                       self._get_eos_token_id(lora_request), lora_request,
                       prompt_adapter_request)

        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Least-loaded scheduler; min() keeps the first one on ties.
        least_loaded = min(
            self.scheduler,
            key=lambda sched: sched.get_num_unfinished_seq_groups())
        least_loaded.add_seq_group(seq_group)

    def stop_remote_worker_execution_loop(self) -> None:
        """Ask the model executor to stop its remote worker execution loop."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()

    def process_model_inputs(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Turn user-facing ``inputs`` into processed LLMInputs.

        A plain string is treated as a prompt. The prompt is tokenized unless
        ``prompt_token_ids`` is already supplied. With a prompt adapter,
        ``prompt_adapter_num_virtual_tokens`` zero ids are prepended. The
        result is passed through the model's registered input processor.

        Raises:
            ValueError: if the prompt must be tokenized but the engine was
                created with ``skip_tokenizer_init``.
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        if "prompt_token_ids" in inputs:
            prompt_token_ids = inputs["prompt_token_ids"]
        else:
            tokenizer_group = self.get_tokenizer_group(
                "prompts must be None if skip_tokenizer_init is True")
            prompt_token_ids = tokenizer_group.encode(
                request_id=request_id,
                prompt=inputs["prompt"],
                lora_request=lora_request)

        if prompt_adapter_request:
            num_virtual_tokens = (
                prompt_adapter_request.prompt_adapter_num_virtual_tokens)
            prompt_token_ids = [0] * num_virtual_tokens + prompt_token_ids

        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
                               prompt=inputs.get("prompt"),
                               multi_modal_data=inputs.get("multi_modal_data"))
        return self.input_processor(llm_inputs)

    def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: The LoRA adapter to apply, if any. Only allowed
                when the engine was created with a LoRA config.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter to apply, if any.

        Raises:
            ValueError: if ``lora_request`` is given but LoRA is not enabled,
                if the prompt must be tokenized but the engine was created
                with ``skip_tokenizer_init``, or if ``params`` is neither
                SamplingParams nor PoolingParams.

        Details:
            - Set arrival_time to the current time if it is None.
            - Tokenize the prompt if ``inputs`` carries no prompt_token_ids.
            - Create a single :class:`~vllm.Sequence` object.
            - Create a :class:`~vllm.SequenceGroup` object
              from that :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            ...    str(request_id),
            ...    example_prompt,
            ...    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = self.process_model_inputs(
            request_id=request_id,
            inputs=inputs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> SequenceGroup:
        """Build a :class:`SequenceGroup` for a generation request.

        The requested logprob counts are validated against the model's
        ``max_logprobs``. A private copy of ``sampling_params`` is then
        completed from the engine's generation-config fields, and ``seq`` is
        wrapped in a new group.

        Args:
            request_id: Unique ID of the request.
            seq: The prompt sequence that seeds the group.
            sampling_params: Caller-supplied sampling parameters. They are
                cloned and never mutated.
            arrival_time: Time the request arrived.
            lora_request: LoRA adapter to apply, if any.
            trace_headers: OpenTelemetry trace headers, if any.
            prompt_adapter_request: Prompt adapter to apply, if any.

        Returns:
            The newly created sequence group.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model config's ``max_logprobs``.
        """
        max_allowed = self.get_model_config().max_logprobs
        requested_counts = (sampling_params.logprobs,
                            sampling_params.prompt_logprobs)
        # None / 0 mean "not requested" and are never over the limit.
        if any(count and count > max_allowed for count in requested_counts):
            raise ValueError(f"Cannot request more than "
                             f"{max_allowed} logprobs.")

        # Work on a defensive copy because the sampler consumes these params.
        # clone() does not deep-copy LogitsProcessor objects.
        params = sampling_params.clone()
        params.update_from_generation_config(self.generation_config_fields,
                                             seq.eos_token_id)

        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> SequenceGroup:
        """Build a :class:`SequenceGroup` for a pooling (embedding) request.

        Args:
            request_id: Unique ID of the request.
            seq: The prompt sequence that seeds the group.
            pooling_params: Caller-supplied pooling parameters. They are
                cloned and never mutated.
            arrival_time: Time the request arrived.
            lora_request: LoRA adapter to apply, if any.
            prompt_adapter_request: Prompt adapter to apply, if any.

        Returns:
            The newly created sequence group. No sampling params are set.
        """
        # The pooler consumes these params, so the group gets its own copy.
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params.clone(),
            prompt_adapter_request=prompt_adapter_request,
        )
747

Antoni Baum's avatar
Antoni Baum committed
748
749
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests by ID.

        The abort is forwarded unchanged to every scheduler, one per virtual
        engine. The caller therefore does not need to know which scheduler
        owns the request.

        Args:
            request_id: A single request ID, or an iterable of request IDs.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        schedulers = self.scheduler
        for sched in schedulers:
            sched.abort_seq_group(request_id)
767

768
769
770
771
    def get_model_config(self) -> ModelConfig:
        """Return the engine's model configuration (``self.model_config``)."""
        model_config = self.model_config
        return model_config

772
773
774
775
    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's decoding configuration
        (``self.decoding_config``)."""
        decoding_config = self.decoding_config
        return decoding_config

776
    def get_num_unfinished_requests(self) -> int:
        """Count unfinished sequence groups, summed over every scheduler."""
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total
780

781
    def has_unfinished_requests(self) -> bool:
        """Return True if any scheduler still holds unfinished sequences.

        Stops at the first scheduler that reports unfinished work.
        """
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Return True if one virtual engine's scheduler has unfinished
        sequences.

        Args:
            virtual_engine: Index into the engine's list of schedulers.
        """
        target_scheduler = self.scheduler[virtual_engine]
        return target_scheduler.has_unfinished_seqs()
792

793
794
795
796
797
798
799
800
801
802
803
804
    def _process_sequence_group_outputs(
        self,
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        """Store an embedding result on ``seq_group`` and finish the group.

        Only ``outputs[0]`` is read. Its embeddings are attached to the
        group, and every sequence in the group is then marked
        ``FINISHED_STOPPED``.
        """
        first_output = outputs[0]
        seq_group.embeddings = first_output.embeddings
        for member in seq_group.get_seqs():
            member.status = SequenceStatus.FINISHED_STOPPED

805
    def _process_model_outputs(
        self,
        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
        scheduled_seq_groups: List[ScheduledSequenceGroup],
        ignored_seq_groups: List[SequenceGroup],
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Apply the model output to the sequences in the scheduled seq groups.

        Steps, in order:
          1. Regroup ``output`` from [step][sequence group] to
             [sequence group][step].
          2. For each scheduled group, advance its computed-token count by
             the scheduled chunk size, then apply its outputs. In embedding
             mode this stores embeddings. Otherwise it applies prompt
             logprobs and, when the metadata says to sample, the sampled
             outputs.
          3. Free finished sequence groups in every scheduler.
          4. Build one request output per scheduled group, followed by one
             per ignored group.

        Args:
            output: Raw per-step model outputs.
            scheduled_seq_groups: Groups that were run this iteration.
            ignored_seq_groups: Groups the scheduler ignored this iteration.
            seq_group_metadata_list: Metadata parallel to
                ``scheduled_seq_groups``.

        Returns:
            RequestOutputs that can be returned to the client.
        """
        now = time.time()

        per_group_outputs = create_output_by_sequence_group(
            output, num_seq_groups=len(scheduled_seq_groups))

        for scheduled, group_outputs, metadata in zip(
                scheduled_seq_groups, per_group_outputs,
                seq_group_metadata_list):
            group = scheduled.seq_group
            group.update_num_computed_tokens(scheduled.token_chunk_size)

            if self.model_config.embedding_mode:
                self._process_sequence_group_outputs(group, group_outputs)
            else:
                self.output_processor.process_prompt_logprob(
                    group, group_outputs)
                if metadata.do_sample:
                    self.output_processor.process_outputs(
                        group, group_outputs)

        # Free the finished sequence groups before building outputs.
        for scheduler in self.scheduler:
            scheduler.free_finished_seq_groups()

        request_outputs: List[Union[RequestOutput,
                                    EmbeddingRequestOutput]] = []
        for scheduled in scheduled_seq_groups:
            group = scheduled.seq_group
            # Update first-token timing before the output object is built.
            group.maybe_set_first_token_time(now)
            request_outputs.append(RequestOutputFactory.create(group))
        request_outputs.extend(
            RequestOutputFactory.create(ignored)
            for ignored in ignored_seq_groups)
        return request_outputs

856
    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Raises:
            NotImplementedError: If ``pipeline_parallel_size > 1``. Pipeline
                parallelism is only supported through AsyncLLMEngine.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>             print(request_output)
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # Pipeline parallelism is rejected above, so only scheduler 0 is
        # driven here.
        primary_scheduler = self.scheduler[0]
        seq_group_metadata_list, scheduler_outputs = \
            primary_scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing to run. Ignored groups are still reported and stats are
            # still logged below.
            output = []
        else:
            finished_ids = \
                primary_scheduler.get_and_reset_finished_requests_ids()
            model_request = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_ids)
            output = self.model_executor.execute_model(
                execute_model_req=model_request)

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)

        if not self.has_unfinished_requests():
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            self.model_executor.stop_remote_worker_execution_loop()

        return request_outputs
Antoni Baum's avatar
Antoni Baum committed
949

950
951
952
953
954
955
956
957
958
959
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` as a stat logger under ``logger_name``.

        Raises:
            KeyError: If a logger with that name is already registered.
        """
        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            KeyError: If no logger with that name is registered.
        """
        try:
            del self.stat_loggers[logger_name]
        except KeyError:
            raise KeyError(
                f"Logger with name {logger_name} does not exist.") from None

960
961
962
963
    def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Collect engine stats and hand them to every registered stat logger.

        ``step`` calls this after each iteration with that iteration's
        scheduler and model outputs. It can also be called with no arguments
        to force a log when no requests are active. It does nothing unless
        ``log_stats`` is enabled.

        Args:
            scheduler_outputs: Outputs of the last scheduling pass, if any.
            model_output: Model outputs of the last step, if any. Used for
                speculative-decoding metrics.
        """
        if not self.log_stats or not self.stat_loggers:
            return
        # Build the snapshot once. Every logger then sees identical numbers
        # taken at the same timestamp, and the stats are not recomputed for
        # each logger.
        stats = self._get_stats(scheduler_outputs, model_output)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(stats)
968

969
970
971
972
973
974
975
976
977
978
979
980
    def _get_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs],
            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
        """Get Stats to be logged by the registered stat loggers.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.

        Returns:
            A :class:`Stats` snapshot. It holds system state (queue sizes and
            KV-cache usage, summed over all schedulers), per-iteration token
            and latency metrics, and metadata for requests that finished in
            this iteration.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        # Check for zero as well as None, as the CPU branch below does. An
        # empty block pool then reports 0% usage instead of raising
        # ZeroDivisionError.
        if num_total_gpu is not None and num_total_gpu > 0:
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu is not None and num_total_cpu > 0:
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # Keep this an int. It is added into num_generation_tokens_iter,
            # which is a token count, and a float literal here would make
            # that count a float.
            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)

                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    if seq_group.sampling_params is not None:
                        best_of_requests.append(
                            seq_group.sampling_params.best_of)
                        n_requests.append(seq_group.sampling_params.n)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
        )

1137
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Pass ``lora_request`` to the model executor's ``add_lora``.

        Returns the executor's result unchanged.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1139
1140

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model executor to remove LoRA ``lora_id``.

        Returns the executor's result unchanged.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1142

1143
    def list_loras(self) -> Set[int]:
        """Return the LoRA IDs reported by the model executor."""
        executor = self.model_executor
        return executor.list_loras()
1145

1146
1147
1148
    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model executor to pin LoRA ``lora_id``.

        Returns the executor's result unchanged.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Pass ``prompt_adapter_request`` to the model executor.

        Returns the executor's ``add_prompt_adapter`` result unchanged.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Ask the model executor to remove prompt adapter
        ``prompt_adapter_id``.

        Returns the executor's result unchanged.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the prompt adapter IDs reported by the model executor."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1159
    def check_health(self) -> None:
        """Run health checks on the tokenizer, then the model executor.

        The tokenizer is checked only when ``self.tokenizer`` is truthy. Any
        exception raised by either check propagates to the caller.
        """
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        self.model_executor.check_health()
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221

    def is_tracing_enabled(self) -> bool:
        """Return True when a tracer is configured (``self.tracer`` is set)."""
        tracer_missing = self.tracer is None
        return not tracer_missing

    def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
        """Emit a trace span for each scheduled sequence group that has
        finished.

        Does nothing when no tracer is configured.
        """
        if self.tracer is not None:
            finished_groups = (
                scheduled.seq_group
                for scheduled in scheduler_outputs.scheduled_seq_groups
                if scheduled.seq_group.is_finished())
            for group in finished_groups:
                self.create_trace_span(group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record an ``llm_request`` server span for ``seq_group``.

        The span is started at the group's arrival time, under the trace
        context extracted from the group's trace headers. Its attributes
        hold the model name, the request's sampling parameters, token usage
        and latency metrics. ``do_tracing`` calls this only for finished
        groups.

        Returns without recording anything when tracing is disabled, or when
        the group has no sampling params. Groups built with pooling params
        are created without sampling params.

        Args:
            seq_group: The sequence group to record.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        # arrival_time is in seconds (it defaults to time.time()); span start
        # times are given as integer nanoseconds.
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            # NOTE(review): assumes first_token_time and finished_time are
            # already set (not None) for a finished group; TODO confirm.
            ttft = metrics.first_token_time - metrics.arrival_time
            e2e_time = metrics.finished_time - metrics.arrival_time
            # attribute names are based on
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
            seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
                                   seq_group.sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
                                   seq_group.sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                   seq_group.sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
                                   seq_group.sampling_params.best_of)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                   seq_group.sampling_params.n)
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            # Completion tokens: total output length over finished sequences.
            seq_span.set_attribute(
                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                sum([
                    seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()
                ]))
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                                   metrics.time_in_queue)
            seq_span.set_attribute(
                SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)