llm_engine.py 64.5 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from contextlib import contextmanager
3
4
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
                    Mapping, Optional)
5
from typing import Sequence as GenericSequence
6
from typing import Set, Tuple, Type, TypeVar, Union
7

8
9
from typing_extensions import assert_never

10
import vllm.envs as envs
11
12
13
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
14
                         PromptAdapterConfig, SchedulerConfig,
15
                         SpeculativeConfig)
16
17
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.engine.arg_utils import EngineArgs
19
20
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
                                 StatLoggerBase, Stats)
21
22
23
24
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
25
from vllm.executor.executor_base import ExecutorBase
26
from vllm.executor.ray_utils import initialize_ray_cluster
27
28
29
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs,
                         PromptInputs, SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.logger import init_logger
31
from vllm.lora.request import LoRARequest
32
from vllm.multimodal import MultiModalDataDict
33
34
35
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
36
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.sampling_params import SamplingParams
38
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
39
40
                           PoolerOutput, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupMetadata,
41
                           SequenceStatus)
42
43
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
44
from vllm.transformers_utils.config import try_get_generation_config
45
from vllm.transformers_utils.detokenizer import Detokenizer
46
47
from vllm.transformers_utils.tokenizer_group import (
    AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
48
49
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
50
from vllm.utils import Counter
51
from vllm.version import __version__ as VLLM_VERSION
52
53

logger = init_logger(__name__)
# Value passed as ``local_interval`` to the default Logging/Prometheus stat
# loggers that LLMEngine.__init__ creates when no stat_loggers are supplied.
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Load the model's generation config as a plain dict.

    Looks the config up with ``try_get_generation_config`` using the model
    name/path, ``trust_remote_code`` and ``revision`` from ``model_config``.

    Returns:
        ``config.to_diff_dict()`` when a generation config is found
        (presumably only the fields that differ from the defaults), otherwise
        an empty dict.
    """
    config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if config is None:
        return {}

    return config.to_diff_dict()


# Request-output type accepted by LLMEngine.validate_output(s).
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

# (prompt text, prompt token ids, multi-modal data) for a single encoder or
# decoder prompt, as returned by LLMEngine._extract_prompt_components; the
# text is None when the prompt was supplied as token ids.
PromptComponents = Tuple[Optional[str], List[int],
                         Optional[MultiModalDataDict]]
# Same layout as PromptComponents, but the token ids may be None.
# NOTE(review): presumably for a decoder prompt that may be omitted; confirm
# against its users.
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional[MultiModalDataDict]]


class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine_args`)

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading model weights
            (e.g. download directory and load format).
        lora_config (Optional): The configuration related to serving multi-LoRA.
        multimodal_config (Optional): The configuration related to multimodal
            models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding;
            a default :class:`DecodingConfig` is used when None.
        observability_config (Optional): The configuration related to
            observability (e.g. the OTLP traces endpoint); a default
            :class:`ObservabilityConfig` is used when None.
        prompt_adapter_config (Optional): The configuration related to serving
            prompt adapters.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
        stat_loggers (Optional): Custom stat loggers keyed by name; when
            given (and log_stats is True) they replace the default logging
            and Prometheus loggers.
    """

    # Class-wide switch read by validate_output() / validate_outputs();
    # enable_output_validation() turns it on for the duration of a with-block.
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

        return output

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        do_validate = cls.DO_VALIDATE_OUTPUT

        outputs_: List[_O]
        if TYPE_CHECKING or do_validate:
            outputs_ = []
            for output in outputs:
                if not isinstance(output, output_type):
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(output)}")

                outputs_.append(output)
        else:
            outputs_ = outputs

        return outputs_

    # Tokenizer group; __init__ sets it to None when
    # model_config.skip_tokenizer_init is true.
    tokenizer: Optional[BaseTokenizerGroup]

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        observability_config: Optional[ObservabilityConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> None:
        """Build the engine from already-constructed configs.

        ``from_engine_args`` builds these configs from :class:`EngineArgs`
        and calls this constructor. Construction proceeds as follows:

        - Log the effective configuration.
        - Set up the tokenizer and detokenizer, unless
          ``model_config.skip_tokenizer_init`` is set.
        - Instantiate ``executor_class``.
        - Size the KV cache (skipped in embedding mode).
        - Optionally report usage stats.
        - Create one :class:`Scheduler` per pipeline-parallel stage.
        - Wire up stat loggers (when ``log_stats``), tracing, and the
          sequence output processor.

        The config arguments are described in the class docstring.
        """
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "pipeline_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
            "enable_prefix_caching=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            observability_config,
            model_config.seed,
            model_config.served_model_name,
            scheduler_config.use_v2_block_manager,
            cache_config.enable_prefix_caching,
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config or ObservabilityConfig(
        )
        self.log_stats = log_stats

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_processor = INPUT_REGISTRY.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            multimodal_config=multimodal_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
            observability_config=self.observability_config,
        )

        # Embedding models do not need a KV cache.
        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    str(cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prompt_adapter":
                    bool(prompt_adapter_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        # One scheduler per pipeline-parallel stage.
        self.scheduler = [
            Scheduler(scheduler_config, cache_config, lora_config,
                      parallel_config.pipeline_parallel_size)
            for _ in range(parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is only enabled when an OTLP traces endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
386
387
388
389
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
390
391
392
393
394
395
396
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

397
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: EngineConfig) -> Type[ExecutorBase]:
        """Pick the executor class for ``engine_config``.

        An ``ExecutorBase`` subclass passed as
        ``distributed_executor_backend`` is used directly. Otherwise the
        device type (neuron/tpu/cpu/openvino/xpu) and then the backend
        string ("ray"/"mp") select the class; plain ``GPUExecutor`` is the
        fallback. A Ray cluster is initialized whenever a Ray-based executor
        is chosen.

        Raises:
            TypeError: ``distributed_executor_backend`` is a class that does
                not subclass ``ExecutorBase``.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutor
                executor_class = RayTPUExecutor
            else:
                assert distributed_executor_backend is None, (
                    "Unsupported distributed_executor_backend for TPU: "
                    f"{distributed_executor_backend!r}")
                from vllm.executor.tpu_executor import TPUExecutor
                executor_class = TPUExecutor
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.device_config.device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            executor_class = OpenVINOExecutor
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutor
                executor_class = RayXPUExecutor
            else:
                from vllm.executor.xpu_executor import XPUExecutor
                executor_class = XPUExecutor
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingGPUExecutor
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Builds the engine configs with ``engine_args.create_engine_config()``,
        chooses the executor class via ``_get_executor_cls`` and enables stat
        logging unless ``engine_args.disable_log_stats`` is set.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        executor_class = cls._get_executor_cls(engine_config)
        # Create the LLM engine.
        engine = cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

        return engine

    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

480
481
482
483
484
485
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

486
487
488
489
490
491
492
493
494
495
496
    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                   "skip_tokenizer_init is True")

    def get_tokenizer_group(
            self,
            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
        if self.tokenizer is None:
            raise ValueError(fail_msg)

        return self.tokenizer

497
    def get_tokenizer(
498
499
500
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
501
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
502

503
504
505
506
507
508
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Build the tokenizer group from the model, scheduler and parallel
        configs; LoRA support is requested iff a LoRA config is present."""
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            enable_lora=bool(self.lora_config))

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
512
        self.cache_config.verify_with_parallel_config(self.parallel_config)
513
514
515
516
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
517
518
519
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)
520

521
522
523
524
525
526
527
528
529
530
531
532
533
    def _get_bos_token_id(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for BOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id

    def _get_eos_token_id(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
534
535
536
537
538
539
540
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

541
    def _get_decoder_start_token_id(self) -> Optional[int]:
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
        '''
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.
        '''

        if not self.is_encoder_decoder_model():
            logger.warning("Using None for decoder start token id because "
                           "this is not an encoder/decoder model.")
            return None

        if (self.model_config is None or self.model_config.hf_config is None):
            logger.warning("Using None for decoder start token id because "
                           "model config is not available.")
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
            logger.warning("Falling back on <BOS> for decoder start token id "
                           "because decoder start token id is not available.")
            dec_start_token_id = self._get_bos_token_id()

        return dec_start_token_id

567
568
569
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> None:
        """Wrap already-processed inputs in a SequenceGroup and enqueue it.

        A decoder ``Sequence`` is always created. A second, encoder-side
        ``Sequence`` (``from_decoder_prompt=False``, same seq id) is added
        when ``processed_inputs`` contains ``encoder_prompt_token_ids``. The
        group type follows ``params``. The group goes to the scheduler that
        currently has the fewest unfinished sequence groups.

        Raises:
            ValueError: ``params`` is neither SamplingParams nor PoolingParams.
        """
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self._get_eos_token_id(lora_request)

        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        encoder_seq = None
        if 'encoder_prompt_token_ids' in processed_inputs:
            encoder_seq = Sequence(seq_id,
                                   processed_inputs,
                                   block_size,
                                   eos_token_id,
                                   lora_request,
                                   prompt_adapter_request,
                                   from_decoder_prompt=False)

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq)
        elif isinstance(params, PoolingParams):
            # NOTE(review): trace_headers is not forwarded on this path;
            # confirm the pooling constructor does not need it.
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with the fewest unfinished
        # sequence groups; on ties, min() keeps the earliest scheduler.
        min_cost_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        min_cost_scheduler.add_seq_group(seq_group)

    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()
629

630
    _LLMInputComponentsType = Tuple[str, List[int]]
631
632
633

    def _prepare_decoder_input_ids_for_generation(
        self,
634
        decoder_input_ids: Optional[List[int]],
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
    ) -> List[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """

656
        decoder_start_token_id = self._get_decoder_start_token_id()
657
658
659
660
661
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
662
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
663
664
665
666
667
668
669
670
671
672

        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _tokenize_prompt(
        self,
        prompt: str,
673
674
        request_id: str,
        lora_request: Optional[LoRARequest],
675
676
    ) -> List[int]:
        '''
677
        Wrapper around application of the model's tokenizer.
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692

        Arguments:

        * prompt
        * request_id
        * lora_request

        Returns:

        * prompt token ids
        '''

        tokenizer = self.get_tokenizer_group("prompts must be None if "
                                             "skip_tokenizer_init is True")

693
694
695
        return tokenizer.encode(request_id=request_id,
                                prompt=prompt,
                                lora_request=lora_request)
696

697
    def _extract_prompt_components(
698
        self,
699
700
701
702
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
703
        '''
704
        Extract the components of any single encoder or decoder input prompt.
705
706
707
708
709

        Arguments:

        * request_id
        * inputs: single encoder or decoder input prompt
710
        * lora_request: this is only valid for decoder prompts
711
712
713
714
715

        Returns:

        * prompt
        * prompt_token_ids
716
        * multi_modal_data
717
718
        '''

719
        if isinstance(inputs, str):
720
721
722
723
            prompt = inputs
            prompt_token_ids = self._tokenize_prompt(
                prompt,
                request_id=request_id,
724
                lora_request=lora_request,
725
            )
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
            multi_modal_data = None
        elif isinstance(inputs, dict):
            if "prompt_token_ids" in inputs:
                prompt = None
                prompt_token_ids = inputs["prompt_token_ids"]
            else:
                # NOTE: This extra assignment is required to pass mypy
                prompt = parsed_prompt = inputs["prompt"]
                prompt_token_ids = self._tokenize_prompt(
                    parsed_prompt,
                    request_id=request_id,
                    lora_request=lora_request,
                )

            multi_modal_data = inputs.get("multi_modal_data")
741
        else:
742
            assert_never(inputs)
743

744
        return prompt, prompt_token_ids, multi_modal_data
745

746
747
748
749
750
751
752
753
754
    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)
755

756
        return prompt_token_ids
757

758
    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
        '''
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
        other models may have different or more 
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
        '''

        bos_token_id = self._get_bos_token_id()
        assert bos_token_id is not None
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
        return [bos_token_id]

    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
    ) -> EncoderDecoderLLMInputs:
        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps

        if encoder_mm_data is not None or decoder_mm_data is not None:
            raise ValueError("Multi-modal encoder-decoder models are "
                             "not supported yet")

        decoder_prompt_ids = (
            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))

        return EncoderDecoderLLMInputs(
            prompt_token_ids=decoder_prompt_ids,
            prompt=decoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_ids,
            encoder_prompt=encoder_prompt,
        )
815
816
817
818

    def _process_encoder_decoder_prompt(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        '''
        For encoder/decoder models only: turn an input prompt into an
        :class:`EncoderDecoderLLMInputs` instance.

        Two prompt shapes are handled:

        * Singleton prompt: it is treated as the encoder prompt, and the
          decoder components are left empty (``None, None, None``) so a
          default decoder prompt is derived later.
        * Explicit encoder/decoder prompt: the ``encoder_prompt`` and, if not
          ``None``, the ``decoder_prompt`` are each extracted separately.

        Each sub-prompt may itself be any singleton prompt type, which is why
        extraction is delegated to ``_extract_prompt_components``.

        Arguments:

        * inputs: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderLLMInputs` instance
        '''
        empty_decoder: DecoderPromptComponents = (None, None, None)

        if not is_explicit_encoder_decoder_prompt(inputs):
            singleton_comps = self._extract_prompt_components(
                inputs,
                request_id=request_id,
            )
            return self._build_enc_dec_llm_inputs(singleton_comps,
                                                  empty_decoder)

        encoder_part = self._extract_prompt_components(
            inputs["encoder_prompt"],
            request_id=request_id,
        )

        decoder_input = inputs["decoder_prompt"]
        if decoder_input is not None:
            decoder_part = self._extract_prompt_components(
                decoder_input,
                request_id=request_id,
            )
        else:
            decoder_part = empty_decoder

        return self._build_enc_dec_llm_inputs(encoder_part, decoder_part)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        '''
        Package decoder-only prompt components as :class:`LLMInputs`.

        The token ids first go through ``_apply_prompt_adapter``, which may
        prepend placeholder ids for a prompt adapter.
        '''
        text, token_ids, mm_data = prompt_comps
        adapted_ids = self._apply_prompt_adapter(
            token_ids, prompt_adapter_request=prompt_adapter_request)
        return LLMInputs(
            prompt_token_ids=adapted_ids,
            prompt=text,
            multi_modal_data=mm_data,
        )
893
894

    def _process_decoder_only_prompt(
895
        self,
896
897
        inputs: SingletonPromptInputs,
        request_id: str,
898
        lora_request: Optional[LoRARequest] = None,
899
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
900
    ) -> LLMInputs:
901
902
        '''
        For decoder-only models:
903
        Process an input prompt into an :class:`LLMInputs` instance.
904
905
906
907
908

        Arguments:

        * inputs: input prompt
        * request_id
909
        * lora_request
910
911
912
913
        * prompt_adapter_request

        Returns:

914
        * :class:`LLMInputs` instance
915
916
        '''

917
918
919
920
921
        prompt_comps = self._extract_prompt_components(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )
922

923
924
925
926
        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )
927
928
929
930

    def process_model_inputs(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Turn a user prompt into the inputs the model consumes.

        Encoder/decoder models map the prompt onto separate encoder and
        decoder prompts. ``lora_request`` and ``prompt_adapter_request`` are
        only forwarded on the decoder-only path. Either way, the result is
        passed through ``self.input_processor`` before being returned.

        Raises:
            ValueError: If an explicit encoder/decoder prompt is given to a
                decoder-only model.
        """
        if not self.is_encoder_decoder_model():
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            prepared = self._process_decoder_only_prompt(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
        else:
            # Encoder-decoder models need the prompt split into encoder and
            # decoder parts.
            prepared = self._process_encoder_decoder_prompt(
                inputs,
                request_id=request_id,
            )

        return self.input_processor(prepared)
957

958
959
960
    def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time from :func:`time.time` is used (not a
                monotonic clock), so it is comparable with the
                ``time.time()`` timestamps the engine uses for its stats.
            lora_request: Optional LoRA adapter for this request. Passing one
                requires LoRA to be enabled on the engine.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Optional prompt adapter for this request.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled.

        Details:
            - Set arrival_time to the current time if it is None.
            - Process ``inputs`` into model inputs; the prompt is tokenized
              if token ids were not supplied.
            - Create `best_of` number of :class:`~vllm.Sequence` objects.
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            # Wall-clock time, deliberately matching the time.time() based
            # timestamps used elsewhere in the engine.
            arrival_time = time.time()

        processed_inputs = self.process_model_inputs(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )
1032
1033
1034
1035
1036
1037

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
1038
1039
        arrival_time: float,
        lora_request: Optional[LoRARequest],
1040
        trace_headers: Optional[Mapping[str, str]] = None,
1041
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
1042
        encoder_seq: Optional[Sequence] = None,
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams."""
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

1053
1054
1055
        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
1056

1057
        sampling_params.update_from_generation_config(
1058
            self.generation_config_fields, seq.eos_token_id)
1059

1060
        # Create the sequence group.
1061
1062
1063
1064
1065
1066
1067
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
1068
1069
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq)
1070

1071
1072
1073
1074
1075
1076
1077
        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
    ) -> SequenceGroup:
        """Build a SequenceGroup for a pooling (embedding) request.

        The pooling params are cloned first so the pooler works on the
        group's own copy rather than the caller's object.
        """
        params_copy = pooling_params.clone()
        group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=params_copy,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq)
        return group
1096

Antoni Baum's avatar
Antoni Baum committed
1097
1098
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests by ID.

        The ID (or iterable of IDs) is handed unchanged to
        ``abort_seq_group`` on every scheduler the engine owns.

        Args:
            request_id: The ID(s) of the request to abort.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for engine_scheduler in self.scheduler:
            engine_scheduler.abort_seq_group(request_id)
1116

1117
1118
1119
1120
    def get_model_config(self) -> ModelConfig:
        """Return the engine's model configuration (``self.model_config``)."""
        config = self.model_config
        return config

1121
1122
1123
1124
    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's parallel configuration (``self.parallel_config``)."""
        config = self.parallel_config
        return config

1125
1126
1127
1128
    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's decoding configuration (``self.decoding_config``)."""
        config = self.decoding_config
        return config

1129
1130
1131
1132
1133
1134
1135
1136
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's scheduler configuration (``self.scheduler_config``)."""
        config = self.scheduler_config
        return config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Gets the LoRA configuration.

        Returns:
            ``self.lora_config``. This may be ``None`` (or otherwise falsy)
            when LoRA is not enabled. ``add_request`` treats a falsy value as
            "LoRA not enabled", so callers must check before using it.
        """
        return self.lora_config

1137
    def get_num_unfinished_requests(self) -> int:
        """Count unfinished sequence groups across all schedulers."""
        total = 0
        for engine_scheduler in self.scheduler:
            total += engine_scheduler.get_num_unfinished_seq_groups()
        return total
1141

1142
    def has_unfinished_requests(self) -> bool:
        """Return True if any scheduler still has unfinished sequences.

        Stops at the first scheduler that reports unfinished work.
        """
        for engine_scheduler in self.scheduler:
            if engine_scheduler.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """
        Return whether the scheduler at index ``virtual_engine`` still has
        unfinished sequences.
        """
        target_scheduler = self.scheduler[virtual_engine]
        return target_scheduler.has_unfinished_seqs()
1153

1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
    def _process_sequence_group_outputs(
        self,
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        seq_group.embeddings = outputs[0].embeddings

        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_STOPPED

        return

1166
    def _process_model_outputs(
        self,
        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
        scheduled_seq_groups: List[ScheduledSequenceGroup],
        ignored_seq_groups: List[SequenceGroup],
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Apply the model output to the sequences in the scheduled seq groups.

        Args:
            output: Per-step model outputs for this iteration (``step``
                passes an empty list when nothing was scheduled).
            scheduled_seq_groups: The sequence groups that were run.
            ignored_seq_groups: Groups the scheduler ignored; they still get
                a request output.
            seq_group_metadata_list: Metadata zipped with
                ``scheduled_seq_groups``; its ``do_sample`` flag gates
                output processing.

        Returns RequestOutputs that can be returned to the client.
        """

        now = time.time()

        # Organize outputs by [sequence group][step] instead of
        # [step][sequence group].
        output_by_sequence_group = create_output_by_sequence_group(
            output, num_seq_groups=len(scheduled_seq_groups))

        # Model timings are per step and identical for every scheduled group,
        # so gather them once instead of re-scanning ``output`` per group.
        # ``None`` timings (presumably timing collection disabled) are
        # skipped: adding ``None`` to an already-accumulated float would
        # raise TypeError.
        forward_times: List[float] = []
        execute_times: List[float] = []
        if output is not None:
            for step_output in output:
                if not isinstance(step_output, SamplerOutput):
                    continue
                if step_output.model_forward_time is not None:
                    forward_times.append(step_output.model_forward_time)
                if step_output.model_execute_time is not None:
                    execute_times.append(step_output.model_execute_time)

        def _accumulate(total: Optional[float],
                        times: List[float]) -> Optional[float]:
            # Add ``times`` to ``total`` in order; ``None`` means no time yet.
            for elapsed in times:
                total = elapsed if total is None else total + elapsed
            return total

        # Update the scheduled sequence groups with the model outputs.
        for scheduled_seq_group, outputs, seq_group_meta in zip(
                scheduled_seq_groups, output_by_sequence_group,
                seq_group_metadata_list):
            seq_group = scheduled_seq_group.seq_group
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)

            metrics = seq_group.metrics
            if metrics is not None:
                metrics.model_forward_time = _accumulate(
                    metrics.model_forward_time, forward_times)
                metrics.model_execute_time = _accumulate(
                    metrics.model_execute_time, execute_times)

            if self.model_config.embedding_mode:
                self._process_sequence_group_outputs(seq_group, outputs)
                continue

            self.output_processor.process_prompt_logprob(seq_group, outputs)
            if seq_group_meta.do_sample:
                self.output_processor.process_outputs(seq_group, outputs)

        # Free the finished sequence groups.
        for scheduler in self.scheduler:
            scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[Union[RequestOutput,
                                    EmbeddingRequestOutput]] = []
        for scheduled_seq_group in scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_outputs.append(RequestOutputFactory.create(seq_group))
        for seq_group in ignored_seq_groups:
            request_outputs.append(RequestOutputFactory.create(seq_group))
        return request_outputs

1233
    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
              This is skipped when the schedule is empty.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it logs stats, records tracing and returns the newly
              generated results. If no unfinished requests remain, the
              remote workers' execution loop is stopped before returning.

        Raises:
            NotImplementedError: If pipeline parallelism is configured; it is
                only supported through AsyncLLMEngine.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>             print(request_output)
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # With pipeline parallelism ruled out above, only the first
        # scheduler is driven from this synchronous path.
        primary_scheduler = self.scheduler[0]
        seq_group_metadata_list, scheduler_outputs = (
            primary_scheduler.schedule())

        if scheduler_outputs.is_empty():
            output = []
        else:
            finished_ids = (
                primary_scheduler.get_and_reset_finished_requests_ids())
            model_request = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_ids)
            output = self.model_executor.execute_model(
                execute_model_req=model_request)

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats, then tracing.
        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)

        if not self.has_unfinished_requests():
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            self.model_executor.stop_remote_worker_execution_loop()

        return request_outputs
Antoni Baum's avatar
Antoni Baum committed
1326

1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` in ``self.stat_loggers`` under ``logger_name``.

        Raises:
            KeyError: If a logger with that name is already registered; the
                existing entry is left untouched.
        """
        registry = self.stat_loggers
        if logger_name not in registry:
            registry[logger_name] = logger
            return
        raise KeyError(f"Logger with name {logger_name} already exists.")

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named ``logger_name``.

        Raises:
            KeyError: If no logger with that name is registered.
        """
        registry = self.stat_loggers
        if logger_name in registry:
            del registry[logger_name]
            return
        raise KeyError(f"Logger with name {logger_name} does not exist.")

1337
1338
1339
1340
    def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Collect stats and hand them to every registered stat logger.

        Does nothing unless ``self.log_stats`` is truthy. Both arguments are
        optional, so this can also force a log when no requests are active.
        """
        if not self.log_stats:
            return

        snapshot = self._get_stats(scheduler_outputs, model_output)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(snapshot)
1346

1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
    def _get_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs],
            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
        """
1359
        now = time.time()
Woosuk Kwon's avatar
Woosuk Kwon committed
1360

1361
1362
        # System State
        #   Scheduler State
1363
1364
1365
1366
1367
1368
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)
1369
1370

        # KV Cache Usage in %
1371
        num_total_gpu = self.cache_config.num_gpu_blocks
1372
1373
        gpu_cache_usage_sys = 0.
        if num_total_gpu is not None:
1374
1375
1376
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
1377
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
Woosuk Kwon's avatar
Woosuk Kwon committed
1378

1379
        num_total_cpu = self.cache_config.num_cpu_blocks
1380
        cpu_cache_usage_sys = 0.
1381
        if num_total_cpu is not None and num_total_cpu > 0:
1382
1383
1384
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
1385
1386
1387
1388
1389
1390
1391
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
1392
1393
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
1407
        if scheduler_outputs is not None:
1408
            num_generation_tokens_from_prefill_groups = 0.
1409
1410
1411
1412
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.
1413
1414
1415
1416

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
1417
                seq_group = scheduled_seq_group.seq_group
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
1446
                if seq_group.is_finished():
1447
                    # Latency timings
1448
1449
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
1450

1451
1452
1453
1454
1455
1456
1457
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
1458
1459
1460
1461
                    if seq_group.sampling_params is not None:
                        best_of_requests.append(
                            seq_group.sampling_params.best_of)
                        n_requests.append(seq_group.sampling_params.n)
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
1476

1477
1478
1479
1480
1481
1482
1483
1484
        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

1485
1486
        return Stats(
            now=now,
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
1501
            spec_decode_metrics=spec_decode_metrics,
1502
            num_preemption_iter=num_preemption_iter,
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
1513
1514
        )

1515
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Hand ``lora_request`` to the model executor's ``add_lora``.

        Returns:
            Whatever boolean the executor's ``add_lora`` reports.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1517
1518

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model executor to remove the LoRA identified by ``lora_id``.

        Returns:
            The boolean result of the executor's ``remove_lora``.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1520

1521
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model executor's ``list_loras``."""
        executor = self.model_executor
        return executor.list_loras()
1523

1524
1525
1526
    def pin_lora(self, lora_id: int) -> bool:
        """Forward a pin request for ``lora_id`` to the model executor.

        Returns:
            The boolean result of the executor's ``pin_lora``.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Hand ``prompt_adapter_request`` to the model executor.

        Returns:
            The boolean result of the executor's ``add_prompt_adapter``.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Ask the model executor to drop the prompt adapter with this id.

        Returns:
            The boolean result of the executor's ``remove_prompt_adapter``.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the prompt-adapter ids reported by the model executor."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1537
    def check_health(self) -> None:
        """Run the tokenizer's and the model executor's health checks.

        The tokenizer is only checked when ``self.tokenizer`` is truthy; the
        executor check always runs afterwards. Exceptions raised by either
        check propagate to the caller unchanged.
        """
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        executor = self.model_executor
        executor.check_health()
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599

    def is_tracing_enabled(self) -> bool:
        """Report whether a tracer is configured (``self.tracer`` not None)."""
        tracer = self.tracer
        return tracer is not None

    def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
        """Emit a trace span for every finished group scheduled this step.

        Does nothing when ``self.tracer`` is None. Otherwise walks
        ``scheduler_outputs.scheduled_seq_groups`` in order and passes each
        ``seq_group`` whose ``is_finished()`` is true to
        ``create_trace_span``.
        """
        if self.tracer is not None:
            scheduled = scheduler_outputs.scheduled_seq_groups
            for group in (entry.seq_group for entry in scheduled):
                if group.is_finished():
                    self.create_trace_span(group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record one ``llm_request`` OpenTelemetry span for ``seq_group``.

        Returns without doing anything when ``self.tracer`` is None or when
        the group has no ``sampling_params``. Otherwise a SERVER-kind span is
        opened, parented on the context extracted from
        ``seq_group.trace_headers`` and starting at the request's arrival
        time, and the request, usage and latency attributes are set on it.
        """
        tracer = self.tracer
        if tracer is None:
            return
        params = seq_group.sampling_params
        if params is None:
            return

        # arrival_time * 1e9 truncated to int is used as the span start
        # (presumably seconds -> nanoseconds).
        start_ns = int(seq_group.metrics.arrival_time * 1e9)
        parent_context = extract_trace_context(seq_group.trace_headers)

        with tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=parent_context,
                start_time=start_ns) as span:
            metrics = seq_group.metrics
            time_to_first_token = (metrics.first_token_time -
                                   metrics.arrival_time)
            end_to_end = metrics.finished_time - metrics.arrival_time

            # Attribute names are based on
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
            # Values are evaluated in this order; attributes are then set in
            # the same order.
            attributes = (
                (SpanAttributes.LLM_RESPONSE_MODEL, self.model_config.model),
                (SpanAttributes.LLM_REQUEST_ID, seq_group.request_id),
                (SpanAttributes.LLM_REQUEST_TEMPERATURE, params.temperature),
                (SpanAttributes.LLM_REQUEST_TOP_P, params.top_p),
                (SpanAttributes.LLM_REQUEST_MAX_TOKENS, params.max_tokens),
                (SpanAttributes.LLM_REQUEST_BEST_OF, params.best_of),
                (SpanAttributes.LLM_REQUEST_N, params.n),
                (SpanAttributes.LLM_USAGE_NUM_SEQUENCES, seq_group.num_seqs()),
                (SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                 len(seq_group.prompt_token_ids)),
                (SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                 sum(seq.get_output_len()
                     for seq in seq_group.get_finished_seqs())),
                (SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                 metrics.time_in_queue),
                (SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN,
                 time_to_first_token),
                (SpanAttributes.LLM_LATENCY_E2E, end_to_end),
            )
            for attr_name, attr_value in attributes:
                span.set_attribute(attr_name, attr_value)

            # The remaining latencies are only reported when recorded.
            if metrics.scheduler_time is not None:
                span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            forward_time = metrics.model_forward_time
            if forward_time is not None:
                # Unlike its neighbours this value is divided by 1000 —
                # NOTE(review): presumably recorded in ms, reported in s.
                span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
                    forward_time / 1000.0)
            execute_time = metrics.model_execute_time
            if execute_time is not None:
                span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                    execute_time)
1612
1613

    def is_encoder_decoder_model(self):
        """Return ``self.model_config.is_encoder_decoder_model`` as-is."""
        config = self.model_config
        return config.is_encoder_decoder_model
1615
1616

    def is_embedding_model(self):
        """Return ``self.model_config.is_embedding_model`` as-is."""
        config = self.model_config
        return config.is_embedding_model