llm_engine.py 71.5 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from contextlib import contextmanager
3
from dataclasses import dataclass
4
5
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
                    Mapping, Optional)
6
from typing import Sequence as GenericSequence
7
from typing import Set, Tuple, Type, Union
8

9
import torch
10
from typing_extensions import TypeVar, assert_never
11

12
import vllm.envs as envs
13
14
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
15
                         ObservabilityConfig, ParallelConfig,
16
                         PromptAdapterConfig, SchedulerConfig,
17
                         SpeculativeConfig)
18
19
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.engine.arg_utils import EngineArgs
21
from vllm.engine.metrics_types import StatLoggerBase, Stats
22
23
24
25
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
26
from vllm.executor.executor_base import ExecutorBase
27
from vllm.executor.ray_utils import initialize_ray_cluster
28
29
30
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                         InputRegistry, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
31
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.logger import init_logger
33
from vllm.lora.request import LoRARequest
34
from vllm.multimodal import MultiModalDataDict
35
36
37
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
38
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.sampling_params import SamplingParams
40
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
41
42
                           PoolerOutput, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupMetadata,
43
                           SequenceStatus)
44
45
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
46
from vllm.transformers_utils.config import try_get_generation_config
47
from vllm.transformers_utils.detokenizer import Detokenizer
48
from vllm.transformers_utils.tokenizer import AnyTokenizer
49
from vllm.transformers_utils.tokenizer_group import (
50
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
51
52
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
53
from vllm.utils import Counter, Device
54
from vllm.version import __version__ as VLLM_VERSION
55
56

logger = init_logger(__name__)
# Value passed as ``local_interval`` to the default logging and Prometheus
# stat loggers that ``LLMEngine.__init__`` builds when none are supplied.
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
58

59

60
61
62
63
64
65
66
67
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Load the model's generation config as a plain dict.

    Returns an empty dict when ``try_get_generation_config`` yields None for
    the model; otherwise returns the config's ``to_diff_dict()``.
    """
    cfg = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    return {} if cfg is None else cfg.to_diff_dict()

72

73
# Tokenizer-group type requested from ``LLMEngine.get_tokenizer_group``.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
# Request-output type checked by ``LLMEngine.validate_output(s)``.
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

# (prompt text, prompt token ids, multi-modal data) extracted from a single
# encoder or decoder prompt.
PromptComponents = Tuple[Optional[str], List[int],
                         Optional[MultiModalDataDict]]
# Same layout for a decoder prompt, whose token ids may be absent.
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional[MultiModalDataDict]]

81

82
83
84
85
86
87
88
89
@dataclass
class SchedulerOutputState:
    """Scheduler outputs cached for one virtual engine (used for multi-step).

    ``LLMEngine.__init__`` creates one instance per pipeline-parallel virtual
    engine. Every field defaults to None, i.e. nothing cached yet.
    """
    # Last sampler output cached for this virtual engine, if any
    # (presumably the output of the most recent step — confirm with caller).
    last_output: Optional[SamplerOutput] = None
    # Sequence-group metadata from the cached scheduling step, if any.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # SchedulerOutputs from the cached scheduling step, if any.
    scheduler_outputs: Optional[SchedulerOutputs] = None

90
class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs` (see
    :ref:`engine_args`).

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading the model
            (e.g. download directory and load format).
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding;
            a default ``DecodingConfig()`` is used when omitted.
        observability_config (Optional): The configuration related to
            observability (e.g. the OTLP traces endpoint); a default
            ``ObservabilityConfig()`` is used when omitted.
        prompt_adapter_config (Optional): The configuration related to serving
            prompt adapters.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
        stat_loggers (Optional): Stat loggers to use when ``log_stats`` is
            set; by default a logging and a Prometheus stat logger are created.
        input_registry: The input registry used to create the input processor.
    """
123

124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    # Class-wide switch read by validate_output / validate_outputs and
    # toggled by the enable_output_validation context manager.
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Turn on request-output type validation for the ``with`` body.

        While active, ``validate_output`` / ``validate_outputs`` check each
        output with ``isinstance``. On exit the previous value of
        ``DO_VALIDATE_OUTPUT`` is restored, even if the body raises. This
        way a failing block does not leave validation on, and a nested use
        does not switch it off for the enclosing one.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` unchanged, type-checking it when enabled.

        The ``isinstance`` check runs only while ``DO_VALIDATE_OUTPUT`` is
        truthy (or for static type checkers, via ``TYPE_CHECKING``).

        Raises:
            TypeError: if checking is active and ``output`` is not an
                instance of ``output_type``.
        """
        checking = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if checking and not isinstance(output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return output

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Type-check every element of ``outputs`` when validation is on.

        With validation off, ``outputs`` itself is returned without copying.
        With it on, a new list holding the same elements is returned once
        all of them pass.

        Raises:
            TypeError: on the first element that is not an ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            return outputs  # type: ignore[return-value]

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)
        return checked

    # Tokenizer group built by __init__; None when
    # model_config.skip_tokenizer_init is set.
    tokenizer: Optional[BaseTokenizerGroup]

175
176
177
178
179
180
    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        observability_config: Optional[ObservabilityConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
    ) -> None:
        """Construct the engine from already-built configs.

        Arguments are described in the class docstring. Construction order
        matters: the executor is created before ``_initialize_kv_caches``,
        which writes the profiled block counts into ``cache_config``, and
        only then are the schedulers built from that ``cache_config``.
        """
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "pipeline_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
            "num_scheduler_steps=%d, enable_prefix_caching=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            observability_config,
            model_config.seed,
            model_config.served_model_name,
            scheduler_config.use_v2_block_manager,
            scheduler_config.num_scheduler_steps,
            cache_config.enable_prefix_caching,
        )
        # TODO(woosuk): Print more configs in debug mode.
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        # Optional configs fall back to their defaults.
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config or ObservabilityConfig(
        )
        self.log_stats = log_stats

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
            observability_config=self.observability_config,
        )

        # Embedding-mode models skip KV cache initialization.
        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    str(cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prompt_adapter":
                    bool(prompt_adapter_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        # One scheduler is created per pipeline-parallel virtual engine.
        self.scheduler = [
            Scheduler(scheduler_config, cache_config, lora_config,
                      parallel_config.pipeline_parallel_size)
            for _ in range(parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        # NOTE: self.stat_loggers is only assigned when log_stats is True.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is enabled only when an OTLP traces endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        # One cached-output slot per pipeline-parallel virtual engine.
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

399
400
401
402
403
404
405
406
407
408
409
    def _initialize_kv_caches(self) -> None:
        """Size and allocate the KV cache on the worker(s).

        The executor reports how many GPU and CPU (swap) blocks fit. A
        non-None ``num_gpu_blocks_override`` in the cache config replaces the
        reported GPU count (with an info log). The final counts are stored
        on ``self.cache_config`` and passed to the executor's
        ``initialize_cache``.
        """
        gpu_blocks, cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        override = self.cache_config.num_gpu_blocks_override
        if override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", gpu_blocks, override)
            gpu_blocks = override

        self.cache_config.num_gpu_blocks = gpu_blocks
        self.cache_config.num_cpu_blocks = cpu_blocks

        self.model_executor.initialize_cache(gpu_blocks, cpu_blocks)

421
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: EngineConfig) -> Type[ExecutorBase]:
        """Choose the executor class for the engine's device and backend.

        Precedence: an explicit ``ExecutorBase`` subclass passed as
        ``distributed_executor_backend``; then the device type (neuron, tpu,
        cpu, openvino, xpu); then the ``"ray"`` / ``"mp"`` backend strings;
        otherwise the single-process GPU executor. Every Ray-based choice
        initializes the Ray cluster before returning.

        Raises:
            TypeError: if ``distributed_executor_backend`` is a class that
                does not subclass ``ExecutorBase``.
        """
        parallel_config = engine_config.parallel_config
        backend = parallel_config.distributed_executor_backend

        # A user-supplied executor class takes priority over everything.
        if isinstance(backend, type):
            if not issubclass(backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {backend}.")
            if backend.uses_ray:  # type: ignore
                initialize_ray_cluster(parallel_config)
            return backend

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            return NeuronExecutor

        if device_type == "tpu":
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutor
                return RayTPUExecutor
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutor
            return TPUExecutor

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            return CPUExecutor

        if device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            return OpenVINOExecutor

        if device_type == "xpu":
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutor
                return RayXPUExecutor
            from vllm.executor.xpu_executor import XPUExecutor
            return XPUExecutor

        # Remaining cases are GPU executors selected by backend name.
        if backend == "ray":
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            return RayGPUExecutor

        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            return MultiprocessingGPUExecutor

        from vllm.executor.gpu_executor import GPUExecutor
        return GPUExecutor

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Build an engine from :class:`EngineArgs`.

        The args are expanded into an engine config, an executor class is
        picked for that config, and ``cls`` is instantiated with the config
        fields plus the executor class, the stat-logging flag (the inverse
        of ``disable_log_stats``), the usage context and the stat loggers.
        """
        config = engine_args.create_engine_config()
        executor_cls = cls._get_executor_cls(config)
        return cls(
            **config.to_dict(),
            executor_class=executor_cls,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
498

499
500
501
502
503
    def __reduce__(self):
        # Pickling is refused outright so that the engine can never end up
        # referenced by the closure used to initialize Ray worker actors.
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

504
505
506
507
508
509
    def __del__(self):
        # Shut the model executor down when the engine is collected. The
        # getattr default covers an __init__ that failed before assigning it.
        executor = getattr(self, "model_executor", None)
        if executor:
            executor.shutdown()

510
511
512
513
    # Default ValueError message used by get_tokenizer_group when
    # self.tokenizer is None.
    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                   "skip_tokenizer_init is True")

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
        *,
        missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
    ) -> _G:
        """Return ``self.tokenizer`` narrowed to ``group_type``.

        Raises:
            ValueError: with ``missing_msg`` when ``self.tokenizer`` is None
                (the engine was built with ``skip_tokenizer_init``).
            TypeError: when the tokenizer group is not a ``group_type``.
        """
        group = self.tokenizer
        if group is None:
            raise ValueError(missing_msg)
        if isinstance(group, group_type):
            return group
        raise TypeError("Invalid type of tokenizer group. "
                        f"Expected type: {group_type}, but "
                        f"found type: {type(group)}")
529

530
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group gives for ``lora_request``.

        Raises:
            ValueError: if the engine has no tokenizer group.
        """
        group = self.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
535

536
537
538
539
540
541
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Create the tokenizer group from the engine's configs.

        LoRA support is requested whenever a LoRA config is present.
        """
        lora_enabled = bool(self.lora_config)
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            enable_lora=lora_enabled)
542

543
544
    def _verify_args(self) -> None:
        """Cross-check the engine's configs against each other.

        Model and cache configs are always checked against the parallel
        config; the LoRA and prompt-adapter configs are checked only when set.
        """
        parallel = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel)
        self.cache_config.verify_with_parallel_config(parallel)

        lora = self.lora_config
        if lora:
            lora.verify_with_model_config(self.model_config)
            lora.verify_with_scheduler_config(self.scheduler_config)

        adapter = self.prompt_adapter_config
        if adapter:
            adapter.verify_with_model_config(self.model_config)
553

554
555
556
557
558
559
560
561
562
563
564
565
566
    def _get_bos_token_id(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """BOS token id of the tokenizer for ``lora_request``.

        Returns None, after logging a warning, when no tokenizer exists.
        """
        group = self.tokenizer
        if group is not None:
            return group.get_lora_tokenizer(lora_request).bos_token_id
        logger.warning("Using None for BOS token id because tokenizer "
                       "is not initialized")
        return None

    def _get_eos_token_id(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """EOS token id of the tokenizer for ``lora_request``.

        Returns None, after logging a warning, when no tokenizer exists.
        """
        group = self.tokenizer
        if group is not None:
            return group.get_lora_tokenizer(lora_request).eos_token_id
        logger.warning("Using None for EOS token id because tokenizer "
                       "is not initialized")
        return None

574
    def _get_decoder_start_token_id(self) -> Optional[int]:
        """Return the decoder start token id of an encoder/decoder model.

        Reads ``decoder_start_token_id`` from the HF config and falls back
        to the BOS token id when it is missing. Returns None, with a
        warning, for non-encoder/decoder models or when the model/HF config
        is unavailable.
        """
        if not self.is_encoder_decoder_model():
            logger.warning("Using None for decoder start token id because "
                           "this is not an encoder/decoder model.")
            return None

        model_config = self.model_config
        hf_config = None if model_config is None else model_config.hf_config
        if hf_config is None:
            logger.warning("Using None for decoder start token id because "
                           "model config is not available.")
            return None

        token_id = getattr(hf_config, 'decoder_start_token_id', None)
        if token_id is not None:
            return token_id

        logger.warning("Falling back on <BOS> for decoder start token id "
                       "because decoder start token id is not available.")
        return self._get_bos_token_id()

600
601
602
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> None:
        """Wrap processed inputs in a sequence group and hand it to a scheduler.

        A decoder ``Sequence`` is always built. An encoder ``Sequence`` with
        the same id is built only when ``processed_inputs`` contains
        ``encoder_prompt_token_ids``. The group factory is chosen by the type
        of ``params``. The group goes to the scheduler with the fewest
        unfinished sequence groups (the first such scheduler on ties).

        Raises:
            ValueError: if ``params`` is neither a ``SamplingParams`` nor a
                ``PoolingParams``.
        """
        self._validate_model_inputs(processed_inputs)

        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self._get_eos_token_id(lora_request)

        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)
        encoder_seq = (Sequence(seq_id,
                                processed_inputs,
                                block_size,
                                eos_token_id,
                                lora_request,
                                prompt_adapter_request,
                                from_decoder_prompt=False)
                       if 'encoder_prompt_token_ids' in processed_inputs else
                       None)

        if not isinstance(params, (SamplingParams, PoolingParams)):
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq)
        else:
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq)

        # min() keeps the first minimal element, matching a first-index
        # tie-break over the per-scheduler load.
        target = min(self.scheduler,
                     key=lambda sched: sched.get_num_unfinished_seq_groups())
        target.add_seq_group(seq_group)

    def stop_remote_worker_execution_loop(self) -> None:
        """Delegate the stop request to the model executor."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
663

664
    # Pair type of (str, token-id list); presumably (prompt text, prompt
    # token ids) — not referenced in the visible code, confirm usage.
    _LLMInputComponentsType = Tuple[str, List[int]]
665
666
667

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[List[int]],
    ) -> List[int]:
        """
        Ensure encoder/decoder decoder input ids begin with the start token.

        Modelled on HuggingFace transformers'
        ``GenerationMixin._prepare_decoder_input_ids_for_generation()``:

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        Arguments:

        * decoder_input_ids: decoder token ids, or None, in which case
          ``_get_default_enc_dec_decoder_prompt()`` supplies the list

        Returns:

        * the ids unchanged if they already start with the decoder start
          token, otherwise a new list with that token prepended (the input
          list is never mutated)
        """
        start_id = self._get_decoder_start_token_id()
        assert start_id is not None

        ids = (self._get_default_enc_dec_decoder_prompt()
               if decoder_input_ids is None else decoder_input_ids)

        if ids and ids[0] == start_id:
            return ids
        return [start_id] + ids

    def _tokenize_prompt(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        '''
        Encode `prompt` into token ids with the engine's tokenizer group.

        The tokenizer group is looked up with a `missing_msg` stating that
        prompts must be None when `skip_tokenizer_init` is True.

        Arguments:

        * prompt: text to tokenize
        * request_id: forwarded to the tokenizer group's `encode`
        * lora_request: forwarded to `encode` (NOTE(review): presumably
          selects a LoRA-specific tokenizer; confirm in the tokenizer group)

        Returns:

        * prompt token ids
        '''
        tokenizer_group = self.get_tokenizer_group(
            missing_msg="prompts must be None if skip_tokenizer_init is True")
        token_ids = tokenizer_group.encode(
            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request,
        )
        return token_ids
730

731
    def _extract_prompt_components(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        '''
        Split one encoder or decoder singleton prompt into its components.

        A plain string is tokenized and carries no multi-modal data. A dict
        either already holds `prompt_token_ids` (then the prompt text is
        None) or holds a `prompt` string that gets tokenized; its optional
        `multi_modal_data` entry is passed through.

        Arguments:

        * inputs: single encoder or decoder input prompt
        * request_id
        * lora_request: this is only valid for decoder prompts

        Returns:

        * prompt
        * prompt_token_ids
        * multi_modal_data
        '''
        if isinstance(inputs, str):
            token_ids = self._tokenize_prompt(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
            )
            return inputs, token_ids, None

        if not isinstance(inputs, dict):
            assert_never(inputs)

        text_prompt: Optional[str]
        if "prompt_token_ids" in inputs:
            text_prompt = None
            token_ids = inputs["prompt_token_ids"]
        else:
            # NOTE: This extra assignment is required to pass mypy
            text_prompt = parsed_prompt = inputs["prompt"]
            token_ids = self._tokenize_prompt(
                parsed_prompt,
                request_id=request_id,
                lora_request=lora_request,
            )

        return text_prompt, token_ids, inputs.get("multi_modal_data")
779

780
781
782
783
784
785
786
787
788
    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        """Left-pad the prompt with placeholder ids for a prompt adapter.

        When a (truthy) prompt adapter request is given, one `0` id is
        prepended per virtual token of the adapter. Otherwise the input list
        is returned unchanged.
        """
        if not prompt_adapter_request:
            return prompt_token_ids
        virtual_padding = (
            [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens)
        return virtual_padding + prompt_token_ids
791

792
    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
        '''
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
        other models may have different or more 
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
        '''

        bos_token_id = self._get_bos_token_id()
        assert bos_token_id is not None
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
        return [bos_token_id]

    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
    ) -> EncoderDecoderLLMInputs:
        """Combine encoder and decoder prompt components into model inputs.

        The decoder token ids are normalized by
        `_prepare_decoder_input_ids_for_generation`.

        Raises:
            ValueError: If either side carries multi-modal data.
        """
        (enc_text, enc_token_ids, enc_mm) = encoder_comps
        (dec_text, dec_token_ids, dec_mm) = decoder_comps

        if not (enc_mm is None and dec_mm is None):
            raise ValueError("Multi-modal encoder-decoder models are "
                             "not supported yet")

        return EncoderDecoderLLMInputs(
            prompt_token_ids=self._prepare_decoder_input_ids_for_generation(
                dec_token_ids),
            prompt=dec_text,
            encoder_prompt_token_ids=enc_token_ids,
            encoder_prompt=enc_text,
        )
849
850
851
852

    def _process_encoder_decoder_prompt(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        '''
        For encoder/decoder models only: turn a prompt into an
        :class:`EncoderDecoderLLMInputs` instance.

        Two kinds of prompt are accepted:

        * a singleton prompt, which is treated as the encoder prompt; the
          decoder side is left empty so a default decoder prompt is inferred
          later
        * an explicit encoder/decoder prompt, whose `encoder_prompt` and
          (possibly None) `decoder_prompt` members are each a singleton prompt
          of any supported type

        Arguments:

        * inputs: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderLLMInputs` instance
        '''
        no_decoder: DecoderPromptComponents = (None, None, None)

        if not is_explicit_encoder_decoder_prompt(inputs):
            singleton_comps = self._extract_prompt_components(
                inputs,
                request_id=request_id,
            )
            return self._build_enc_dec_llm_inputs(singleton_comps, no_decoder)

        encoder_comps: PromptComponents = self._extract_prompt_components(
            inputs["encoder_prompt"],
            request_id=request_id,
        )

        decoder_input = inputs["decoder_prompt"]
        decoder_comps: DecoderPromptComponents
        if decoder_input is None:
            decoder_comps = no_decoder
        else:
            decoder_comps = self._extract_prompt_components(
                decoder_input,
                request_id=request_id,
            )

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        """Wrap decoder-only prompt components into :class:`LLMInputs`.

        Prompt-adapter placeholder ids are applied to the token ids first.
        """
        text, token_ids, mm_data = prompt_comps
        adapted_ids = self._apply_prompt_adapter(
            token_ids, prompt_adapter_request=prompt_adapter_request)
        return LLMInputs(
            prompt_token_ids=adapted_ids,
            prompt=text,
            multi_modal_data=mm_data,
        )
927
928

    def _process_decoder_only_prompt(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        '''
        For decoder-only models: turn a singleton prompt into an
        :class:`LLMInputs` instance.

        The prompt is split into components (tokenizing if needed, with
        `lora_request` forwarded), then the prompt adapter is applied.

        Arguments:

        * inputs: input prompt
        * request_id
        * lora_request
        * prompt_adapter_request

        Returns:

        * :class:`LLMInputs` instance
        '''
        components = self._extract_prompt_components(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )
        return self._build_decoder_only_llm_inputs(
            components,
            prompt_adapter_request=prompt_adapter_request,
        )
961
962
963
964

    def process_model_inputs(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Turn a user prompt into model inputs for this engine.

        Encoder/decoder models go through `_process_encoder_decoder_prompt`.
        `lora_request` and `prompt_adapter_request` are not forwarded on that
        path. Decoder-only models go through `_process_decoder_only_prompt`.
        Either result is passed through `self.input_processor` before it is
        returned.

        Raises:
            ValueError: If an explicit encoder/decoder prompt is given to a
                decoder-only model.
        """
        model_inputs: Union[LLMInputs, EncoderDecoderLLMInputs]

        if not self.is_encoder_decoder_model():
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")
            model_inputs = self._process_decoder_only_prompt(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
        else:
            # Encoder/decoder prompts need their own mapping onto the
            # encoder and decoder inputs.
            model_inputs = self._process_encoder_decoder_prompt(
                inputs,
                request_id=request_id,
            )

        return self.input_processor(model_inputs)
991

992
993
994
    def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time from :func:`time.time` is used
                (not a monotonic clock).
            lora_request: The LoRA request to apply, if any. It is only
                accepted when LoRA is enabled for this engine.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter request to apply, if
                any.

        Raises:
            ValueError: If `lora_request` is given but LoRA is not enabled.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create `best_of` number of :class:`~vllm.Sequence` objects.
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = self.process_model_inputs(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )
1066
1067
1068
1069
1070
1071

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
    ) -> SequenceGroup:
        """Build a SequenceGroup for a generation (sampling) request.

        Rejects requests whose `logprobs` or `prompt_logprobs` exceed the
        model's `max_logprobs`. It then works on a clone of
        `sampling_params`, applies the generation-config defaults to that
        clone, and wraps `seq` in a new SequenceGroup.

        Raises:
            ValueError: If more logprobs are requested than the model allows.
        """
        max_logprobs = self.get_model_config().max_logprobs
        requested_counts = (sampling_params.logprobs,
                            sampling_params.prompt_logprobs)
        if any(count and count > max_logprobs for count in requested_counts):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

        # Defensive copy: the sampler consumes these params and the update
        # below is applied to the copy only. clone() does not deep-copy
        # LogitsProcessor objects.
        params_copy = sampling_params.clone()
        params_copy.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=params_copy,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
        )

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
    ) -> SequenceGroup:
        """Build a SequenceGroup for a pooling request.

        The group receives a clone of `pooling_params` (a defensive copy,
        since the pooler consumes these params), not the caller's object.
        """
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params.clone(),
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
        )
1130

Antoni Baum's avatar
Antoni Baum committed
1131
1132
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        Args:
            request_id: The ID(s) of the request to abort. The value is handed
                unchanged to every scheduler.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for virtual_engine_scheduler in self.scheduler:
            virtual_engine_scheduler.abort_seq_group(request_id)
1150

1151
1152
1153
1154
    def get_model_config(self) -> ModelConfig:
        """Return the engine's model configuration object (not a copy)."""
        return self.model_config

1155
1156
1157
1158
    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's parallel configuration object (not a copy)."""
        return self.parallel_config

1159
1160
1161
1162
    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's decoding configuration object (not a copy)."""
        return self.decoding_config

1163
1164
1165
1166
1167
1168
1169
1170
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's scheduler configuration object (not a copy)."""
        return self.scheduler_config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Gets the LoRA configuration.

        This may be None (or otherwise falsy) when LoRA is not enabled;
        `add_request` treats a falsy `lora_config` as "LoRA not enabled".
        """
        return self.lora_config

1171
    def get_num_unfinished_requests(self) -> int:
        """Return the total number of unfinished sequence groups.

        The per-scheduler counts are summed over every scheduler; with no
        schedulers the result is 0.
        """
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total
1175

1176
    def has_unfinished_requests(self) -> bool:
        """Return True as soon as any scheduler reports unfinished sequences."""
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """
        Report whether the scheduler at index `virtual_engine` still has
        unfinished sequences.
        """
        engine_scheduler = self.scheduler[virtual_engine]
        return engine_scheduler.has_unfinished_seqs()
1187

1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
    def _process_sequence_group_outputs(
        self,
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        """Store an embedding result on `seq_group` and finish its sequences.

        Only the first element of `outputs` is read. Every sequence in the
        group is then marked `FINISHED_STOPPED`.
        """
        first_output = outputs[0]
        seq_group.embeddings = first_output.embeddings

        for sequence in seq_group.get_seqs():
            sequence.status = SequenceStatus.FINISHED_STOPPED

1200
    def _process_model_outputs(
        self,
        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
        scheduled_seq_groups: List[ScheduledSequenceGroup],
        ignored_seq_groups: List[SequenceGroup],
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Apply the model output to the sequences in the scheduled seq groups.

        Args:
            output: Per-step model outputs for this batch (empty when nothing
                was executed).
            scheduled_seq_groups: The scheduled sequence groups; they are
                matched positionally with the per-group outputs and with
                `seq_group_metadata_list`.
            ignored_seq_groups: Sequence groups the scheduler ignored; they
                only get a RequestOutput built for them.
            seq_group_metadata_list: Metadata of the scheduled groups, in the
                same order as `scheduled_seq_groups`.

        Returns:
            RequestOutputs that can be returned to the client.
        """

        def _accumulate(total: Optional[float],
                        delta: Optional[float]) -> Optional[float]:
            # A step may report no timing (None). Never add None to a float;
            # that would raise TypeError once a time has been recorded.
            if delta is None:
                return total
            if total is None:
                return delta
            return total + delta

        now = time.time()

        # Organize outputs by [sequence group][step] instead of
        # [step][sequence group].
        output_by_sequence_group = create_output_by_sequence_group(
            output, num_seq_groups=len(scheduled_seq_groups))

        # Update the scheduled sequence groups with the model outputs.
        for scheduled_seq_group, outputs, seq_group_meta in zip(
                scheduled_seq_groups, output_by_sequence_group,
                seq_group_metadata_list):
            seq_group = scheduled_seq_group.seq_group
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)

            metrics = seq_group.metrics
            if metrics is not None and output:
                # Each step's timings are added to every group in the batch.
                for o in output:
                    if isinstance(o, SamplerOutput):
                        metrics.model_forward_time = _accumulate(
                            metrics.model_forward_time, o.model_forward_time)
                        metrics.model_execute_time = _accumulate(
                            metrics.model_execute_time, o.model_execute_time)

            if self.model_config.embedding_mode:
                self._process_sequence_group_outputs(seq_group, outputs)
                continue

            self.output_processor.process_prompt_logprob(seq_group, outputs)
            if seq_group_meta.do_sample:
                self.output_processor.process_outputs(seq_group, outputs)

        # Free the finished sequence groups.
        for scheduler in self.scheduler:
            scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[Union[RequestOutput,
                                    EmbeddingRequestOutput]] = []
        for scheduled_seq_group in scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(seq_group)
            request_outputs.append(request_output)
        for seq_group in ignored_seq_groups:
            request_output = RequestOutputFactory.create(seq_group)
            request_outputs.append(request_output)
        return request_outputs

1267
    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

            With multi-step scheduling, the scheduler is only called again
            once every step of the cached batch has run. Until then this
            method returns an empty list.

        Raises:
            NotImplementedError: If pipeline parallelism is configured; use
                AsyncLLMEngine for that.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>             print(request_output)
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[0]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
            seq_group_metadata_list, scheduler_outputs = self.scheduler[
                0].schedule()

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    0, seq_group_metadata_list, scheduler_outputs)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                0].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(0)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(0, output)
        else:
            output = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group_metadata in seq_group_metadata_list:
                seq_group_metadata.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()
            request_outputs = self._process_model_outputs(
                output, scheduler_outputs.scheduled_seq_groups,
                scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        else:
            request_outputs = []

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        # Tracing
        self.do_tracing(scheduler_outputs)

        if not self.has_unfinished_requests():
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            self.model_executor.stop_remote_worker_execution_loop()

        return request_outputs
Antoni Baum's avatar
Antoni Baum committed
1410

1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return True if the current multi-step batch still has steps left.

        Always False when multi-step scheduling is disabled or when
        `seq_group_metadata_list` is None or empty.

        Raises:
            AssertionError: If the sequence groups disagree on the number of
                remaining steps.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs) -> None:
        """Store this step's schedule in the cache slot of ``virtual_engine``.

        Both the sequence-group metadata and the scheduler outputs are
        recorded, and any previously cached ``last_output`` for that slot is
        reset to ``None``.
        """
        cached_state = self.cached_scheduler_outputs[virtual_engine]
        cached_state.seq_group_metadata_list = seq_group_metadata_list
        cached_state.scheduler_outputs = scheduler_outputs
        cached_state.last_output = None

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Cache the final entry of ``output`` as ``last_output`` for
        ``virtual_engine``.

        Only happens with more than one pipeline-parallel stage and when
        ``output`` is non-empty with a non-``None`` first entry; otherwise the
        cache is left untouched. The cached entry is asserted to carry only
        ``sampled_token_ids_cpu`` (no device token ids or probabilities).
        """
        has_pipeline = self.parallel_config.pipeline_parallel_size > 1
        if not has_pipeline or len(output) == 0 or output[0] is None:
            return

        newest = output[-1]
        assert newest is not None
        assert newest.sampled_token_ids_cpu is not None
        assert newest.sampled_token_ids is None
        assert newest.sampled_token_probs is None
        self.cached_scheduler_outputs[virtual_engine].last_output = newest

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` in ``self.stat_loggers`` under ``logger_name``.

        Raises:
            KeyError: if a logger is already registered under that name.
        """
        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            KeyError: if no logger is registered under that name.
        """
        try:
            del self.stat_loggers[logger_name]
        except KeyError:
            raise KeyError(
                f"Logger with name {logger_name} does not exist.") from None

1475
1476
1477
1478
    def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Build a ``Stats`` snapshot and pass it to every registered logger.

        Invoked after each engine iteration with that iteration's scheduler
        and model outputs; both arguments default to ``None`` so it can also
        be called bare to force a log while no requests are active. Does
        nothing when ``self.log_stats`` is false.
        """
        if self.log_stats:
            stats = self._get_stats(scheduler_outputs, model_output)
            for logger in self.stat_loggers.values():
                logger.log(stats)
1484

1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
    def _get_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs],
            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.

        Returns:
            A ``Stats`` object holding system state (scheduler queue sizes,
            KV-cache usage, prefix-cache hit rates), iteration stats (token
            counts, TTFT/TPOT latencies, preemptions, spec-decode metrics) and
            request stats for the sequence groups that finished this
            iteration.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        # Guard against a zero block count as well as an unset one (same as
        # the CPU branch below) so the division can never raise.
        if num_total_gpu is not None and num_total_gpu > 0:
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu is not None and num_total_cpu > 0:
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # A token count: keep it an int so num_generation_tokens_iter
            # (derived from it below) stays an int too.
            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)

                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    if seq_group.sampling_params is not None:
                        best_of_requests.append(
                            seq_group.sampling_params.best_of)
                        n_requests.append(seq_group.sampling_params.n)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output. Entries of the model output may be None (the cache
        # update path in this engine types them Optional and checks
        # output[0]), so check before dereferencing.
        spec_decode_metrics = None
        if model_output and model_output[0] is not None:
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
        )

1663
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to ``self.model_executor.add_lora`` and
        return its boolean result."""
        return self.model_executor.add_lora(lora_request)
1665
1666

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to ``self.model_executor.remove_lora`` and
        return its boolean result."""
        return self.model_executor.remove_lora(lora_id)
1668

1669
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by ``self.model_executor.list_loras``.
        """
        return self.model_executor.list_loras()
1671

1672
1673
1674
    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model executor's ``pin_lora`` for ``lora_id`` and
        hand back whatever it returns."""
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Delegate ``prompt_adapter_request`` to the model executor's
        ``add_prompt_adapter`` and hand back its result."""
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Delegate removal of ``prompt_adapter_id`` to the model executor's
        ``remove_prompt_adapter`` and hand back its result."""
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Hand back the prompt adapter ids reported by the model executor's
        ``list_prompt_adapters``."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1685
    def check_health(self) -> None:
        """Run the tokenizer's health check (when a tokenizer is set), then
        the model executor's.

        Exceptions raised by either check are not caught here and propagate
        to the caller.
        """
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747

    def is_tracing_enabled(self) -> bool:
        """Report whether a tracer has been configured on this engine."""
        tracer = self.tracer
        return tracer is not None

    def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
        """Create a trace span for each scheduled sequence group that is
        finished; a no-op when ``self.tracer`` is ``None``."""
        if self.tracer is not None:
            # Lazily filtered, so each span is created right after its
            # group's is_finished() check, as in a plain loop.
            finished_groups = (
                scheduled.seq_group
                for scheduled in scheduler_outputs.scheduled_seq_groups
                if scheduled.seq_group.is_finished())
            for seq_group in finished_groups:
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit an ``llm_request`` span describing ``seq_group``.

        The span starts at the request's arrival time and is parented to the
        context extracted from ``seq_group.trace_headers``. It carries the
        model name, request id, sampling parameters, token usage, and latency
        attributes. Does nothing when no tracer is configured or the group
        has no ``sampling_params``.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            ttft = metrics.first_token_time - metrics.arrival_time
            e2e_time = metrics.finished_time - metrics.arrival_time
            # attribute names are based on
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
            seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
                                   seq_group.sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
                                   seq_group.sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                   seq_group.sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
                                   seq_group.sampling_params.best_of)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                   seq_group.sampling_params.n)
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                                   metrics.time_in_queue)
            seq_span.set_attribute(
                SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
            # Optional timings: only attached when they were recorded.
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # NOTE(review): divided by 1000.0, presumably converting a
                # millisecond measurement to seconds — confirm the unit.
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)
1760
1761

    def is_encoder_decoder_model(self):
        """Return ``self.model_config.is_encoder_decoder_model``."""
        return self.model_config.is_encoder_decoder_model
1763
1764

    def is_embedding_model(self):
        """Return ``self.model_config.is_embedding_model``."""
        return self.model_config.is_embedding_model
1766
1767
1768
1769
1770
1771
1772

    def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                   EncoderDecoderLLMInputs]):
        prompt_key = "encoder_prompt_token_ids" \
            if self.is_encoder_decoder_model() else "prompt_token_ids"
        if not inputs.get(prompt_key):
            raise ValueError("Prompt cannot be empty")