"tests/vscode:/vscode.git/clone" did not exist on "f2b8e1c5510cf3621dc4b910f0eba5289d9fee88"
llm_engine.py 75.3 KB
Newer Older
1
import os
Antoni Baum's avatar
Antoni Baum committed
2
import time
3
from collections import deque
4
from contextlib import contextmanager
5
from dataclasses import dataclass
6
from functools import partial
7
8
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                    Iterable, List, Mapping, NamedTuple, Optional)
9
from typing import Sequence as GenericSequence
10
from typing import Set, Type, Union
11

12
import torch
13
from typing_extensions import TypeVar
14

15
import vllm.envs as envs
16
17
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
18
                         ObservabilityConfig, ParallelConfig,
19
                         PromptAdapterConfig, SchedulerConfig,
20
                         SpeculativeConfig)
21
22
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
28
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
29
from vllm.executor.executor_base import ExecutorBase
30
from vllm.executor.gpu_executor import GPUExecutor
31
from vllm.executor.ray_utils import initialize_ray_cluster
32
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
33
34
                         InputRegistry, LLMInputs, PromptInputs)
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.logger import init_logger
36
from vllm.lora.request import LoRARequest
37
from vllm.model_executor.layers.sampler import SamplerOutput
38
39
40
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
41
from vllm.prompt_adapter.request import PromptAdapterRequest
42
from vllm.sampling_params import RequestOutputKind, SamplingParams
43
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
44
                           Sequence, SequenceGroup, SequenceGroupMetadata,
45
                           SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID)
46
47
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
48
from vllm.transformers_utils.config import try_get_generation_config
49
from vllm.transformers_utils.detokenizer import Detokenizer
50
from vllm.transformers_utils.tokenizer import AnyTokenizer
51
from vllm.transformers_utils.tokenizer_group import (
52
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
53
54
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
55
from vllm.utils import Counter, Device, weak_bind
56
from vllm.version import __version__ as VLLM_VERSION
57
58

logger = init_logger(__name__)
59
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
60

61

62
63
64
65
66
67
68
69
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a dict.

    Looks the config up via ``try_get_generation_config`` using the model
    name, ``trust_remote_code`` and ``revision`` from ``model_config``.

    Returns:
        ``config.to_diff_dict()`` when a config is found, otherwise an empty
        dict.
    """
    config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    # No generation config available for this model.
    if config is None:
        return {}

    return config.to_diff_dict()

74

75
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
76
77
78
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)


79
80
81
82
83
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    # Cached scheduler results; both default to None (nothing cached).
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Cached flag; defaults to False.
    allow_async_output_proc: bool = False
    # Cached sampler output; defaults to None.
    last_output: Optional[SamplerOutput] = None
86
87


88
89
90
91
92
93
94
95
96
class OutputData(NamedTuple):
    """Record type stored in ``SchedulerContext.output_queue``.

    Instances are built by ``SchedulerContext.append_output``, which copies
    its arguments into the fields of the same name and sets ``skip`` to an
    empty list.
    """
    # Sampler outputs passed to append_output.
    outputs: List[SamplerOutput]
    # Sequence group metadata passed alongside the outputs.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Scheduler outputs passed alongside the outputs.
    scheduler_outputs: SchedulerOutputs
    # Flags forwarded unchanged from the append_output caller.
    is_async: bool
    is_last_step: bool
    # List of ints; starts empty when created via append_output.
    # NOTE(review): presumably indices to skip during output processing --
    # confirm against the consumer of output_queue.
    skip: List[int]


97
class SchedulerContext:
    """Mutable per-scheduler state: queued outputs and last scheduler results.

    Args:
        multi_step_stream_outputs: Stored as-is on the instance.
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        # FIFO of OutputData records appended by append_output().
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         EmbeddingRequestOutput]] = []
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool):
        """Wrap the arguments in an ``OutputData`` and enqueue it.

        The record's ``skip`` list always starts out empty.
        """
        self.output_queue.append(
            OutputData(outputs=outputs,
                       seq_group_metadata_list=seq_group_metadata_list,
                       scheduler_outputs=scheduler_outputs,
                       is_async=is_async,
                       is_last_step=is_last_step,
                       skip=[]))
120
121


122
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
123
    """An LLM engine that receives requests and generates texts.
124

Woosuk Kwon's avatar
Woosuk Kwon committed
125
    This is the main class for the vLLM engine. It receives requests
126
127
128
129
130
131
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

132
133
    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
134

135
136
    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine_args`)
137
138
139
140
141
142
143

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
144
        device_config: The configuration related to the device.
145
146
147
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
148
149
        executor_class: The model executor class for managing distributed
            execution.
150
        prompt_adapter_config (Optional): The configuration related to serving
151
            prompt adapters.
152
        log_stats: Whether to log statistics.
153
        usage_context: Specified entry point, used for usage info collection.
154
    """
155

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

        return output

    @classmethod
    def validate_outputs(
zhuwenwen's avatar
zhuwenwen committed
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        do_validate = cls.DO_VALIDATE_OUTPUT

        outputs_: List[_O]
        if TYPE_CHECKING or do_validate:
            outputs_ = []
            for output in outputs:
                if not isinstance(output, output_type):
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(output)}")

                outputs_.append(output)
        else:
            outputs_ = outputs

203
204
205
206
        return outputs_

    tokenizer: Optional[BaseTokenizerGroup]

207
208
209
210
211
212
    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        observability_config: Optional[ObservabilityConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the engine from already-constructed config objects.

        In order, this logs the configuration, loads general plugins, stores
        the configs, creates the tokenizer and detokenizer (unless
        ``skip_tokenizer_init`` is set), the input preprocessor/processor and
        the model executor, initializes the KV caches (unless the model is in
        embedding mode), optionally reports usage stats, and then creates the
        per-``pipeline_parallel_size`` scheduler state, the schedulers, the
        stat loggers, the tracer and the output processor.

        Args:
            decoding_config: Falls back to ``DecodingConfig()`` when falsy.
            observability_config: Falls back to ``ObservabilityConfig()``
                when falsy.
            executor_class: Instantiated here with the config objects.
            log_stats: When true, ``self.stat_loggers`` is set up.
            usage_context: Forwarded to the usage report.
            stat_loggers: Used instead of the default logging/prometheus
                loggers when provided and ``log_stats`` is true.
            input_registry: Used to create the input processor.
            use_cached_outputs: Stored on the instance.
        """
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "override_neuron_config=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "pipeline_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
            "num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
            "enable_prefix_caching=%s, use_async_output_proc=%s, "
            "use_cached_outputs=%s, mm_processor_kwargs=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.override_neuron_config,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            observability_config,
            model_config.seed,
            model_config.served_model_name,
            scheduler_config.use_v2_block_manager,
            scheduler_config.num_scheduler_steps,
            scheduler_config.multi_step_stream_outputs,
            cache_config.enable_prefix_caching,
            model_config.use_async_output_proc,
            use_cached_outputs,
            model_config.mm_processor_kwargs,
        )
        # TODO(woosuk): Print more configs in debug mode.
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config or ObservabilityConfig(
        )
        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_preprocessor = InputPreprocessor(model_config,
                                                    self.tokenizer)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
            observability_config=self.observability_config,
        )

        # KV caches are only set up when the model is not in embedding mode.
        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    str(cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prompt_adapter":
                    bool(prompt_adapter_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cached-output slot and one scheduler context per
        # pipeline_parallel_size entry (indexed by v_id below).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if model_config.use_async_output_proc:
            # weak_bind is used so the callbacks are not created from the
            # plain bound method.
            # NOTE(review): presumably to avoid a strong reference to self
            # -- confirm against vllm.utils.weak_bind.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                scheduler_config, cache_config, lora_config,
                parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if model_config.use_async_output_proc else None)
            for v_id in range(parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is only enabled when an OTLP endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        # True only when the VLLM_TREE_DECODING env var is exactly '1'.
        self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
469

470
471
472
473
474
475
476
477
478
479
480
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
481
482
483
484
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
485
486
487
488
489
490
491
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

492
    @classmethod
493
494
    def _get_executor_cls(cls,
                          engine_config: EngineConfig) -> Type[ExecutorBase]:
495
496
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
497
        # Initialize the cluster and specify the executor class.
498
499
500
501
502
503
504
505
506
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
507
508
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
509
        elif engine_config.device_config.device_type == "tpu":
510
511
512
513
514
515
516
517
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutor
                executor_class = RayTPUExecutor
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutor
                executor_class = TPUExecutor
518
        elif engine_config.device_config.device_type == "cpu":
519
520
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
521
522
523
        elif engine_config.device_config.device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            executor_class = OpenVINOExecutor
524
525
526
527
528
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutor
                executor_class = RayXPUExecutor
529
530
531
532
533
534
535
            elif distributed_executor_backend == "mp":
                # FIXME(kunshang):
                # spawn needs calling `if __name__ == '__main__':``
                # fork is not supported for xpu start new process.
                logger.error(
                    "Both start methods (spawn and fork) have issue "
                    "on XPU if you use mp backend, Please try ray instead.")
536
537
538
            else:
                from vllm.executor.xpu_executor import XPUExecutor
                executor_class = XPUExecutor
539
        elif distributed_executor_backend == "ray":
540
            initialize_ray_cluster(engine_config.parallel_config)
541
542
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
543
544
545
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
546
547
548
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
549
            executor_class = MultiprocessingGPUExecutor
550
551
552
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor
553
554
555
556
557
558
559
560
561
562
563
564
565
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The engine config is built from ``engine_args``, the executor class is
        derived from that config, and stat logging is enabled unless
        ``engine_args.disable_log_stats`` is set.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        executor_class = cls._get_executor_cls(engine_config)
        # Create the LLM engine.
        engine = cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

        return engine
576

577
578
579
580
581
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

582
583
584
585
586
587
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

588
    def get_tokenizer_group(
589
590
591
592
593
594
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
595
596
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
597
598
599
600
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")
601

602
        return tokenizer_group
603

604
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group yields for a request.

        Delegates to ``get_lora_tokenizer(lora_request)`` on the group
        returned by ``get_tokenizer_group()``, so it raises whatever that
        method raises when no tokenizer is available.
        """
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
609

610
611
612
613
614
615
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Create the tokenizer group from the engine's stored configs.

        LoRA support is enabled when ``self.lora_config`` is truthy.
        """
        init_kwargs = dict(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            enable_lora=bool(self.lora_config),
        )
        return init_tokenizer_from_configs(**init_kwargs)
616

617
618
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
619
        self.cache_config.verify_with_parallel_config(self.parallel_config)
620
621
622
623
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
624
625
626
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)
627

628
629
630
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        """Turn processed inputs into a sequence group and schedule it.

        After validating the inputs, a decoder ``Sequence`` is created (plus
        an encoder ``Sequence`` when ``processed_inputs`` contains
        ``'encoder_prompt_token_ids'``), both sharing one sequence id. A
        sequence group is then built according to the type of ``params`` and
        handed to the scheduler that currently reports the fewest unfinished
        sequence groups.

        Raises:
            ValueError: If ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``.
        """
        self._validate_model_inputs(processed_inputs)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        encoder_seq = None
        if 'encoder_prompt_token_ids' in processed_inputs:
            encoder_seq = Sequence(seq_id,
                                   processed_inputs,
                                   block_size,
                                   eos_token_id,
                                   lora_request,
                                   prompt_adapter_request,
                                   from_decoder_prompt=False)

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        # min() returns the first scheduler with the lowest count, so ties
        # resolve to the earliest scheduler in the list.
        min_cost_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        min_cost_scheduler.add_seq_group(seq_group)

    def stop_remote_worker_execution_loop(self) -> None:
        """Ask the model executor to stop its remote worker execution loop."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
694

695
696
697
    def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, we use
                the current wall-clock time (``time.time()``).
            lora_request: Optional LoRA request. It is passed to the input
                preprocessor and attached to the request. Only accepted when
                the engine has a LoRA config.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Optional prompt adapter request. It is
                passed to the input preprocessor and attached to the request.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled,
                or if a non-zero ``priority`` is given while the scheduler
                policy is not ``"priority"``.

        Details:
            - Set arrival_time to the current time if it is None.
            - Run the inputs through the input preprocessor and then the
              input processor.
            - Hand the processed inputs to ``_add_processed_request``, which
              builds the sequence group and adds it to a scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        # Any explicit priority (negative values included) is meaningless
        # unless priority scheduling is on, so reject it instead of silently
        # ignoring it.
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if arrival_time is None:
            arrival_time = time.time()

        preprocessed_inputs = self.input_preprocessor.preprocess(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )
779
780
781
782
783
784

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
785
786
        arrival_time: float,
        lora_request: Optional[LoRARequest],
787
        trace_headers: Optional[Mapping[str, str]] = None,
788
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
789
        encoder_seq: Optional[Sequence] = None,
790
        priority: int = 0,
791
792
793
794
795
796
797
798
799
800
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams."""
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

801
802
803
        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
804

805
        sampling_params.update_from_generation_config(
806
            self.generation_config_fields, seq.eos_token_id)
807

808
        # Create the sequence group.
809
810
811
812
813
814
815
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
816
            prompt_adapter_request=prompt_adapter_request,
817
818
            encoder_seq=encoder_seq,
            priority=priority)
819

820
821
822
823
824
825
826
        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with PoolingParams.

        Args:
            request_id: ID assigned to the new sequence group.
            seq: The single sequence the group is built around.
            pooling_params: Caller-provided pooling parameters. They are
                cloned before use, so the caller's object is not modified.
            arrival_time: Arrival time stored on the sequence group.
            lora_request: Optional LoRA request attached to the group.
            prompt_adapter_request: Optional prompt adapter request attached
                to the group.
            encoder_seq: Optional encoder sequence attached to the group.
            priority: Priority value stored on the group.

        Returns:
            The newly created sequence group.
        """
        # Defensive copy of PoolingParams, which are used by the pooler
        pooling_params = pooling_params.clone()
        # Create the sequence group.
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)
        return seq_group
847

Antoni Baum's avatar
Antoni Baum committed
848
849
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts the request(s) with the given ID(s).

        Args:
            request_id: The ID(s) of the request(s) to abort.

        Details:
            - ``request_id`` is forwarded unchanged to every scheduler held
              by the engine.
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for scheduler in self.scheduler:
            scheduler.abort_seq_group(request_id)
867

868
869
870
871
    def get_model_config(self) -> ModelConfig:
        """Return the model configuration held by the engine."""
        model_config = self.model_config
        return model_config

872
873
874
875
    def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration held by the engine."""
        parallel_config = self.parallel_config
        return parallel_config

876
877
878
879
    def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration held by the engine."""
        decoding_config = self.decoding_config
        return decoding_config

880
881
882
883
884
885
886
887
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration held by the engine."""
        scheduler_config = self.scheduler_config
        return scheduler_config

    def get_lora_config(self) -> LoRAConfig:
        """Return the LoRA configuration held by the engine."""
        lora_config = self.lora_config
        return lora_config

888
    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests.

        Returns:
            The sum of ``get_num_unfinished_seq_groups()`` over all of the
            engine's schedulers.
        """
        return sum(scheduler.get_num_unfinished_seq_groups()
                   for scheduler in self.scheduler)
892

893
    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests.

        Returns:
            True as soon as any of the engine's schedulers reports
            unfinished sequences, False otherwise.
        """
        return any(scheduler.has_unfinished_seqs()
                   for scheduler in self.scheduler)

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Report whether one virtual engine still has unfinished requests.

        Looks up the scheduler at index ``virtual_engine`` and returns the
        result of its ``has_unfinished_seqs()`` call.
        """
        engine_scheduler = self.scheduler[virtual_engine]
        return engine_scheduler.has_unfinished_seqs()
904

905
    @staticmethod
906
907
908
909
910
911
912
913
914
915
916
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        seq_group.embeddings = outputs[0].embeddings

        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_STOPPED

        return

917
918
919
920
921
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and return responses.
922

923
924
        ctx: The virtual engine context to work on
        request_id: If provided, then only this request is going to be processed
925

926
        """
927
        now = time.time()
928

929
        if len(ctx.output_queue) == 0:
930
931
            return None

932
        # Get pending async postprocessor
933
934
935
936
937
938
939
940
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, skip) = ctx.output_queue.popleft()
941
942
943
944
945
946
947
948
949
950

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        # Organize outputs by [step][sequence group] instead of
        # [sequence group][step].
        if len(outputs) > 1:
            outputs_by_sequence_group = create_output_by_sequence_group(
                outputs, num_seq_groups=len(seq_group_metadata_list))
951
        elif len(outputs) == 1:
952
            outputs_by_sequence_group = outputs
953
954
        else:
            return None
955

956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

973
        finished_before: List[int] = []
974
        finished_now: List[int] = []
975
        empty_seq_indices: List[int] = []
976
977
978
979
980
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
981
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
982

983
            seq_group = scheduled_seq_group.seq_group
984
985
986
987
988
989
990
991
992
993

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            if len(outputs) > 1:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

994
            # tree style speculative decoding may generate empty output in first step
995
996
997
998
999
1000
1001
1002
1003
            if self.tree_decoding and outputs and isinstance(output[0], CompletionSequenceGroupOutput):
                samples = [o.samples[0] for o in output]
                valid_samples = [
                    sample for sample in samples
                    if sample.output_token != VLLM_INVALID_TOKEN_ID
                ]
                if len(valid_samples) == 0:
                    empty_seq_indices.append(i)
                    continue
1004

1005
1006
1007
1008
1009
1010
            if not is_async:
                seq_group.update_num_computed_tokens(
                    scheduled_seq_group.token_chunk_size)

            if outputs:
                for o in outputs:
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)
1025

1026
            if self.model_config.embedding_mode:
1027
                self._process_sequence_group_outputs(seq_group, output)
1028
1029
1030
1031
1032
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)
1033

1034
1035
            if seq_group.is_finished():
                finished_now.append(i)
1036

1037
1038
1039
        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
1040

1041
            seq_group = scheduled_seq_group.seq_group
1042
            seq_group.maybe_set_first_token_time(now)
1043
1044
            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
1045
1046
            if request_output:
                ctx.request_outputs.append(request_output)
1047

1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

1060
1061
1062
1063
1064
        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

1065
1066
        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
1067
1068
1069
1070
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
1071
                ctx.request_outputs.clear()
1072
1073
1074
            return

        # Create the outputs
1075
        for i in indices:
1076
            if i in skip or i in finished_before or i in finished_now or i in empty_seq_indices:
1077
1078
                continue  # Avoids double processing

1079
1080
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

1081
            seq_group = scheduled_seq_group.seq_group
1082
            seq_group.maybe_set_first_token_time(now)
1083
1084
            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
1085
            if request_output:
1086
                ctx.request_outputs.append(request_output)
1087

1088
1089
1090
1091
1092
1093
1094
1095
        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

1096
        for seq_group in scheduler_outputs.ignored_seq_groups:
1097
1098
1099
1100
1101
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

1102
1103
            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
1104
1105
            if request_output:
                ctx.request_outputs.append(request_output)
1106

1107
1108
1109
1110
        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
1111
            ctx.request_outputs.clear()
1112

1113
1114
1115
1116
        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
1117
            # Log stats.
1118
1119
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153

            # Tracing
            self.do_tracing(scheduler_outputs)

        return None

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done inside output processor, but it is
        required if the worker is to perform async forward pass to next step.
        """
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

            seq_group.update_num_computed_tokens(
                seq_group_metadata.token_chunk_size)

            if seq_group_metadata.do_sample:
                assert len(sequence_group_outputs.samples) == 1, (
                    "Async output processor expects a single sample"
                    " (i.e sampling_params.n == 1 and no "
                    "sampling_params.best_of > 1)")
                sample = sequence_group_outputs.samples[0]

                assert len(seq_group.seqs) == 1
                seq = seq_group.seqs[0]
                seq.append_token_id(sample.output_token, sample.logprobs)
1154

1155
    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refer to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Returns:
            The request outputs collected in the scheduler context during
            this iteration.

        Raises:
            NotImplementedError: If ``pipeline_parallel_size`` is greater
                than 1; pipeline parallelism is only supported through
                AsyncLLMEngine.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # For llm_engine, there is no pipeline parallel support, so the engine
        # used is always 0.
        virtual_engine = 0

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

            # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps.
            if self.scheduler_config.is_multi_step:
                # Reset the same slot that was read at the top of this method.
                self.cached_scheduler_outputs[virtual_engine] = (
                    SchedulerOutputState())

            # Add results to the output_queue
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True)

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")

                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Check if need to run the usual non-async path
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1349

1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Tell whether the given batch still has multi-step iterations left.

        Args:
            seq_group_metadata_list: Metadata of the scheduled sequence
                groups; may be ``None`` or empty.

        Returns:
            ``False`` when multi-step scheduling is disabled or the list is
            ``None``/empty; otherwise whether the (shared) ``remaining_steps``
            counter of the groups is greater than zero.

        Raises:
            AssertionError: If the groups do not all report the same
                ``remaining_steps`` value.
        """
        if not self.scheduler_config.is_multi_step:
            return False
        if not seq_group_metadata_list:
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort
        # of dynamic scheduling when doing multi-step decoding.
        expected_steps = seq_group_metadata_list[0].state.remaining_steps
        for metadata in seq_group_metadata_list[1:]:
            if metadata.state.remaining_steps != expected_steps:
                raise AssertionError(("All running sequence groups should "
                                      "have the same remaining steps."))

        return expected_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store the current scheduling result in the per-engine cache slot.

        Args:
            virtual_engine: Index into ``self.cached_scheduler_outputs``.
            seq_group_metadata_list: Metadata list to remember.
            scheduler_outputs: Scheduler outputs to remember.
            allow_async_output_proc: Flag to remember alongside the outputs.
        """
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        # A freshly cached schedule has no sampler output attached yet.
        co.last_output = None
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Record the final sampler output of ``output`` in the cache slot.

        Nothing is stored unless pipeline parallelism is in use
        (``pipeline_parallel_size > 1``), ``output`` is non-empty, and its
        first entry is not ``None``.

        Args:
            virtual_engine: Index into ``self.cached_scheduler_outputs``.
            output: Sampler outputs; only the last element is cached.
        """
        if self.parallel_config.pipeline_parallel_size <= 1:
            return
        if len(output) == 0 or output[0] is None:
            return

        final_output = output[-1]
        # The cached output is expected to carry only the CPU copy of the
        # sampled token ids.
        assert final_output is not None
        assert final_output.sampled_token_ids_cpu is not None
        assert final_output.sampled_token_ids is None
        assert final_output.sampled_token_probs is None

        slot = self.cached_scheduler_outputs[virtual_engine]
        slot.last_output = final_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        """Return the CPU sampled token ids of the cached last output.

        Args:
            virtual_engine: Index into ``self.cached_scheduler_outputs``.

        Returns:
            ``last_output.sampled_token_ids_cpu`` of the cache slot when
            multi-step scheduling is enabled, ``pipeline_parallel_size > 1``
            and a last output is cached; ``None`` in every other case.
        """
        previous = self.cached_scheduler_outputs[virtual_engine].last_output
        multi_step_with_pp = (
            self.scheduler_config.is_multi_step
            and self.parallel_config.pipeline_parallel_size > 1)
        if not multi_step_with_pp or previous is None:
            return None
        # May itself be None, which is also what the caller gets back.
        return previous.sampled_token_ids_cpu
Antoni Baum's avatar
Antoni Baum committed
1405

1406
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under ``logger_name``.

        Args:
            logger_name: Key to store the logger under in
                ``self.stat_loggers``.
            logger: The stat logger to register.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If a logger with this name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        self.stat_loggers[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Args:
            logger_name: Key of the logger to drop from
                ``self.stat_loggers``.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If no logger with this name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

1424
1425
1426
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Forced log when no requests active.

        When ``self.log_stats`` is truthy, builds one stats snapshot via
        ``self._get_stats`` and hands it to every registered stat logger;
        otherwise does nothing.

        Args:
            scheduler_outputs: Forwarded to ``_get_stats``.
            model_output: Forwarded to ``_get_stats``.
            finished_before: Forwarded to ``_get_stats``.
            skip: Forwarded to ``_get_stats``.
        """
        if self.log_stats:
            stats = self._get_stats(scheduler_outputs, model_output,
                                    finished_before, skip)
            for logger in self.stat_loggers.values():
                logger.log(stats)
1435

1436
1437
1438
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional, indices of sequences that were finished
                before. These sequences will be ignored.
            skip: Optional, indices of sequences that were preempted. These
                sequences will be ignored.

        Returns:
            A ``Stats`` object holding system-level, per-iteration and
            per-finished-request values computed at the time of the call.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0 (no division by 0)
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            # Kept as an int so the derived generation-token count stays an
            # integer token count.
            num_generation_tokens_from_prefill_groups = 0

            # Index sets built once so the per-group membership tests below
            # do not rescan the lists on every iteration.
            finished_indices = frozenset(finished_before or ())
            skipped_indices = frozenset(skip or ())

            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_indices:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skipped_indices:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    if seq_group.sampling_params is not None:
                        best_of_requests.append(
                            seq_group.sampling_params.best_of)
                        n_requests.append(seq_group.sampling_params.n)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
        )

1632
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model executor's ``add_lora``.

        Returns:
            Whatever the executor's ``add_lora`` returns.
        """
        return self.model_executor.add_lora(lora_request)
1634
1635

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``remove_lora``.

        Returns:
            Whatever the executor's ``remove_lora`` returns.
        """
        return self.model_executor.remove_lora(lora_id)
1637

1638
    def list_loras(self) -> Set[int]:
        """Return the result of the model executor's ``list_loras``."""
        return self.model_executor.list_loras()
1640

1641
1642
1643
    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``pin_lora``.

        Returns:
            Whatever the executor's ``pin_lora`` returns.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward the request to the executor's ``add_prompt_adapter``.

        Returns:
            Whatever the executor's ``add_prompt_adapter`` returns.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward the id to the executor's ``remove_prompt_adapter``.

        Returns:
            Whatever the executor's ``remove_prompt_adapter`` returns.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the result of the executor's ``list_prompt_adapters``."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1654
    def check_health(self) -> None:
        """Run the health checks of the tokenizer (if any) and the executor.

        The tokenizer's ``check_health`` is called only when
        ``self.tokenizer`` is truthy; the model executor's ``check_health``
        is always called afterwards.
        """
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
1658

1659
    def start_profile(self) -> None:
        """Ask the model executor to start profiling.

        Calls ``start_profile`` on the executor directly when it is exactly a
        ``GPUExecutor``; for any other executor type the call is dispatched
        through ``_run_workers("start_profile")``.
        """
        # using type instead of isinstance to check to avoid capturing
        # inherited classes (MultiprocessingGPUExecutor)
        if type(self.model_executor) == GPUExecutor:  # noqa: E721
            self.model_executor.start_profile()
        else:
            self.model_executor._run_workers("start_profile")
1666
1667

    def stop_profile(self) -> None:
        """Ask the model executor to stop profiling.

        Calls ``stop_profile`` on the executor directly when it is exactly a
        ``GPUExecutor``; for any other executor type the call is dispatched
        through ``_run_workers("stop_profile")``.
        """
        # using type instead of isinstance to check to avoid capturing
        # inherited classes (MultiprocessingGPUExecutor)
        if type(self.model_executor) == GPUExecutor:  # noqa: E721
            self.model_executor.stop_profile()
        else:
            self.model_executor._run_workers("stop_profile")
1674

1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
    def is_tracing_enabled(self) -> bool:
        """Return True when a tracer is set (``self.tracer`` is not None)."""
        tracer = self.tracer
        return tracer is not None

    def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
        """Create a trace span for every finished scheduled sequence group.

        Does nothing when no tracer is configured.

        Args:
            scheduler_outputs: Provides ``scheduled_seq_groups``; each entry's
                ``seq_group`` is traced if ``is_finished()`` is true.
        """
        if self.tracer is None:
            return

        # Lazily filtered, so each group is checked right before it is traced.
        finished_groups = (
            scheduled.seq_group
            for scheduled in scheduler_outputs.scheduled_seq_groups
            if scheduled.seq_group.is_finished())
        for finished_group in finished_groups:
            self.create_trace_span(finished_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit one ``llm_request`` span describing ``seq_group``.

        Returns immediately when no tracer is configured or the group has no
        sampling params. Otherwise a server-kind span is started at the
        group's arrival time (converted to nanoseconds) and populated with
        request, usage and latency attributes taken from the group and its
        metrics.

        Args:
            seq_group: The sequence group to describe. Its metrics are read
                for ``arrival_time``, ``first_token_time``, ``finished_time``
                and ``time_in_queue``; the optional scheduler / model timing
                fields are attached only when they are not ``None``.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            ttft = metrics.first_token_time - metrics.arrival_time
            e2e_time = metrics.finished_time - metrics.arrival_time
            # attribute names are based on
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
            seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
                                   seq_group.sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
                                   seq_group.sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                   seq_group.sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
                                   seq_group.sampling_params.best_of)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                   seq_group.sampling_params.n)
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            # Completion tokens are summed over the finished sequences only.
            seq_span.set_attribute(
                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                sum([
                    seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()
                ]))
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                                   metrics.time_in_queue)
            seq_span.set_attribute(
                SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # Scaled by 1/1000 before being reported.
                # NOTE(review): presumably milliseconds -> seconds; confirm
                # against where model_forward_time is recorded.
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)
1745
1746

    def is_encoder_decoder_model(self):
        """Return the input preprocessor's ``is_encoder_decoder_model()``."""
        return self.input_preprocessor.is_encoder_decoder_model()
1748
1749

    def is_embedding_model(self):
        """Return ``self.model_config.is_embedding_model``."""
        return self.model_config.is_embedding_model
1751
1752
1753

    def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                   EncoderDecoderLLMInputs]):
1754
1755
1756
1757
1758
        if self.model_config.is_multimodal_model:
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            prompt_ids = inputs.get("prompt_token_ids")
        elif self.is_encoder_decoder_model():
1759
1760
1761
1762
1763
            prompt_ids = inputs.get("encoder_prompt_token_ids")
        else:
            prompt_ids = inputs.get("prompt_token_ids")

        if prompt_ids is None or len(prompt_ids) == 0:
1764
            raise ValueError("Prompt cannot be empty")
1765

1766
        if self.model_config.is_multimodal_model:
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
            max_prompt_len = self.model_config.max_model_len

            if len(prompt_ids) > max_prompt_len:
                raise ValueError(
                    f"The prompt (total length {len(prompt_ids)}) is too long "
                    f"to fit into the model (context length {max_prompt_len}). "
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
1777
1778
1779

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
zhuwenwen's avatar
zhuwenwen committed
1780
            # max_batch_len = self.scheduler_config.max_num_batched_tokens