llm_engine.py 45.6 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from contextlib import contextmanager
3
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
4
from typing import Sequence as GenericSequence
5
from typing import Set, Type, TypeVar, Union
6

7
from transformers import PreTrainedTokenizer
8

9
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
10
11
                         LoRAConfig, ModelConfig, ObservabilityConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
12
                         VisionLanguageConfig)
13
14
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
15
from vllm.engine.arg_utils import EngineArgs
16
17
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
                                 StatLoggerBase, Stats)
18
19
20
21
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
22
from vllm.executor.executor_base import ExecutorBase
23
from vllm.executor.ray_utils import initialize_ray_cluster
24
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
28
29
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.sampling_params import SamplingParams
31
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
32
33
                           PoolerOutput, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupMetadata,
34
                           SequenceStatus)
35
36
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
37
from vllm.transformers_utils.config import try_get_generation_config
38
from vllm.transformers_utils.detokenizer import Detokenizer
39
40
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
yhu422's avatar
yhu422 committed
41
42
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
43
from vllm.utils import Counter
44
from vllm.version import __version__ as VLLM_VERSION
45
46

# Module-level logger shared by the engine code in this file.
logger = init_logger(__name__)

# Interval, in seconds, handed as `local_interval` to the default stat
# loggers that LLMEngine.__init__ creates when stats logging is enabled.
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation-config fields as a plain dict.

    The generation config for ``model_config.model`` is looked up with the
    model's ``trust_remote_code`` and ``revision`` settings. If none is
    found, an empty dict is returned. Otherwise the result of the config's
    ``to_diff_dict()`` is returned (in transformers this keeps only the
    fields that differ from the defaults).
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    return (generation_config.to_diff_dict()
            if generation_config is not None else {})


# Type variable constrained to the two request-output classes.
# LLMEngine.validate_output / validate_outputs take `output_type: Type[_O]`,
# so their declared return type follows the class that was passed in.
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)


66
class LLMEngine:
    """An LLM engine that receives requests and generates text.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates text from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs` (see
    :ref:`engine_args`).

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading model weights
            (e.g. download directory and load format).
        lora_config (Optional): The configuration related to serving multi-LoRA.
        vision_language_config (Optional): The configuration related to vision
            language models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The decoding configuration; a default
            ``DecodingConfig()`` is used when None.
        observability_config (Optional): The observability configuration; a
            tracer is created when its OTLP traces endpoint is set. Defaults
            to ``ObservabilityConfig()`` when None.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
        stat_loggers (Optional): Stat loggers keyed by name. When None and
            ``log_stats`` is set, a logging and a Prometheus stat logger are
            created.
    """

    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

        return output

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        do_validate = cls.DO_VALIDATE_OUTPUT

        outputs_: List[_O]
        if TYPE_CHECKING or do_validate:
            outputs_ = []
            for output in outputs:
                if not isinstance(output, output_type):
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(output)}")

                outputs_.append(output)
        else:
            outputs_ = outputs

        return outputs_

    # Declared for type checkers; __init__ assigns it (None when
    # model_config.skip_tokenizer_init is set).
    tokenizer: Optional[BaseTokenizerGroup]

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        observability_config: Optional[ObservabilityConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> None:
        """Build the engine: executor, KV cache, scheduler and helpers.

        Steps, in order:

        - Log the configuration.
        - Store the configs. ``decoding_config`` and ``observability_config``
          fall back to default instances when None.
        - Create the tokenizer group and detokenizer, unless
          ``skip_tokenizer_init`` is set.
        - Load generation-config overrides and the input processor.
        - Instantiate ``executor_class``.
        - Size the KV cache (skipped in embedding mode).
        - Report usage stats when enabled.
        - Ping the tokenizer and create the scheduler.
        - Create stat loggers when ``log_stats`` is set (defaults: a logging
          and a Prometheus logger).
        - Create a tracer when an OTLP traces endpoint is configured.
        - Build the sequence output processor.
        """
        config_banner = (
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s)")
        # NOTE: the raw (possibly None) decoding/observability configs are
        # logged here, before the defaults below are applied.
        logger.info(
            config_banner,
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            observability_config,
            model_config.seed,
            model_config.served_model_name,
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        # Optional configs are replaced by default instances when falsy.
        self.decoding_config = decoding_config or DecodingConfig()
        self.observability_config = (observability_config
                                     or ObservabilityConfig())
        self.log_stats = log_stats

        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
            self.detokenizer = None
        else:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)
        self.input_processor = INPUT_REGISTRY.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
        )

        # The KV cache is only sized/allocated outside embedding mode.
        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_kvs = {
                # Common configuration
                "dtype": str(model_config.dtype),
                "tensor_parallel_size": parallel_config.tensor_parallel_size,
                "block_size": cache_config.block_size,
                "gpu_memory_utilization": cache_config.gpu_memory_utilization,
                # Quantization
                "quantization": model_config.quantization,
                "kv_cache_dtype": cache_config.cache_dtype,
                # Feature flags
                "enable_lora": bool(lora_config),
                "enable_prefix_caching": cache_config.enable_prefix_caching,
                "enforce_eager": model_config.enforce_eager,
                "disable_custom_all_reduce":
                parallel_config.disable_custom_all_reduce,
            }
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs=usage_kvs)

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # NOTE: cache_config has by now been updated with the GPU and CPU
        # block counts profiled by the executor (see _initialize_kv_caches).
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric logging.
        if self.log_stats:
            if stat_loggers is None:
                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)
            else:
                self.stat_loggers = stat_loggers

        self.tracer = None
        otlp_endpoint = self.observability_config.otlp_traces_endpoint
        if otlp_endpoint:
            self.tracer = init_tracer("vllm.llm_engine", otlp_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        stop_checker = StopChecker(
            self.scheduler_config.max_model_len,
            self.get_tokenizer_for_seq,
        )
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                self.get_tokenizer_for_seq,
                stop_checker=stop_checker,
            ))

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
348
349
350
351
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
352
353
354
355
356
357
358
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

359
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The engine configs come from ``engine_args.create_engine_config()``.
        The executor class is chosen from the device type (neuron, tpu, cpu,
        openvino, xpu); otherwise it is chosen from the distributed executor
        backend (``"ray"``, ``"mp"``, or else the single-process GPU
        executor). Ray-based executors call ``initialize_ray_cluster`` before
        their executor module is imported.

        Args:
            engine_args: Engine arguments to build the configs from.
            usage_context: Entry point recorded for usage-stats collection.

        Returns:
            The constructed engine, an instance of ``cls``.
        """
        engine_config = engine_args.create_engine_config()
        backend = engine_config.parallel_config.distributed_executor_backend
        device_type = engine_config.device_config.device_type

        # Each executor module is imported only inside the branch that
        # selects it.
        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif device_type == "tpu":
            from vllm.executor.tpu_executor import TPUExecutor
            executor_class = TPUExecutor
        elif device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            executor_class = OpenVINOExecutor
        elif device_type == "xpu" and backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_xpu_executor import RayXPUExecutor
            executor_class = RayXPUExecutor
        elif device_type == "xpu":
            from vllm.executor.xpu_executor import XPUExecutor
            executor_class = XPUExecutor
        elif backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            executor_class = MultiprocessingGPUExecutor
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        return cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
        )

    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

418
419
420
421
422
423
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

424
425
426
427
428
429
430
431
432
433
434
    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                   "skip_tokenizer_init is True")

    def get_tokenizer_group(
            self,
            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
        if self.tokenizer is None:
            raise ValueError(fail_msg)

        return self.tokenizer

435
    def get_tokenizer(self) -> "PreTrainedTokenizer":
436
        return self.get_tokenizer_group().get_lora_tokenizer(None)
437
438
439

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
440
441
        return self.get_tokenizer_group().get_lora_tokenizer(
            sequence.lora_request)
442

443
    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
        """Create the tokenizer group from the engine configs.

        Default keyword arguments are taken from:

        - ``model_config``: tokenizer id, mode, revision and
          ``trust_remote_code``;
        - ``scheduler_config.max_num_seqs``;
        - whether a LoRA config is set.

        Entries in ``tokenizer_init_kwargs`` override those defaults.
        ``parallel_config.tokenizer_pool_config`` is passed through to
        ``get_tokenizer_group``.
        """
        defaults = {
            "tokenizer_id": self.model_config.tokenizer,
            "enable_lora": bool(self.lora_config),
            "max_num_seqs": self.scheduler_config.max_num_seqs,
            "max_input_length": None,
            "tokenizer_mode": self.model_config.tokenizer_mode,
            "trust_remote_code": self.model_config.trust_remote_code,
            "revision": self.model_config.tokenizer_revision,
        }
        merged_kwargs = {**defaults, **tokenizer_init_kwargs}
        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
                                   **merged_kwargs)

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
459
        self.cache_config.verify_with_parallel_config(self.parallel_config)
460
461
462
463
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
464

465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
    def _get_eos_token_id(
            self, lora_request: Optional[LoRARequest]) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: LLMInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Schedule a request whose inputs have already been processed.

        A single :class:`Sequence` is built from ``processed_inputs``. It is
        wrapped in a sampling or pooling sequence group depending on the type
        of ``params``, and the group is added to the scheduler.
        ``trace_headers`` are only forwarded for sampling requests.

        Raises:
            ValueError: If ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``. The sequence id has already been consumed
                at that point.
        """
        block_size = self.cache_config.block_size
        new_seq_id = next(self.seq_counter)
        eos_id = self._get_eos_token_id(lora_request)
        seq = Sequence(new_seq_id, processed_inputs, block_size, eos_id,
                       lora_request)

        # Keyword arguments shared by both sequence-group factories.
        common = dict(arrival_time=arrival_time, lora_request=lora_request)
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id, seq, params, trace_headers=trace_headers,
                **common)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id, seq, params, **common)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        self.scheduler.add_seq_group(seq_group)

    def process_model_inputs(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
    ) -> LLMInputs:
        """Turn ``inputs`` into :class:`LLMInputs` via the input processor.

        A bare string is treated as ``{"prompt": inputs}``. When
        ``prompt_token_ids`` is absent, the prompt is encoded with the
        tokenizer group for ``lora_request``; this raises ``ValueError`` if
        the tokenizer was skipped. The assembled ``LLMInputs`` is passed
        through ``self.input_processor`` and its result is returned.
        """
        prompt_dict = {"prompt": inputs} if isinstance(inputs, str) else inputs

        if "prompt_token_ids" in prompt_dict:
            token_ids = prompt_dict["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            token_ids = tokenizer.encode(request_id=request_id,
                                         prompt=prompt_dict["prompt"],
                                         lora_request=lora_request)

        return self.input_processor(
            LLMInputs(prompt_token_ids=token_ids,
                      prompt=prompt_dict.get("prompt"),
                      multi_modal_data=prompt_dict.get("multi_modal_data")))

    def add_request(
        self,
        request_id: str,
544
        inputs: PromptInputs,
545
        params: Union[SamplingParams, PoolingParams],
546
        arrival_time: Optional[float] = None,
547
        lora_request: Optional[LoRARequest] = None,
548
        trace_headers: Optional[Dict[str, str]] = None,
549
    ) -> None:
Zhuohan Li's avatar
Zhuohan Li committed
550
        """Add a request to the engine's request pool.
551
552

        The request is added to the request pool and will be processed by the
Zhuohan Li's avatar
Zhuohan Li committed
553
        scheduler as `engine.step()` is called. The exact scheduling policy is
554
555
556
557
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
558
559
560
561
562
563
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
564
            arrival_time: The arrival time of the request. If None, we use
565
                the current monotonic time.
566
            trace_headers: OpenTelemetry trace headers.
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create `best_of` number of :class:`~vllm.Sequence` objects.
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
591
        """
592
593
594
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
595
        if arrival_time is None:
596
            arrival_time = time.time()
597

598
599
600
        processed_inputs = self.process_model_inputs(request_id=request_id,
                                                     inputs=inputs,
                                                     lora_request=lora_request)
601

602
603
604
605
606
607
        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
608
            trace_headers=trace_headers,
609
        )
610
611
612
613
614
615

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
616
617
        arrival_time: float,
        lora_request: Optional[LoRARequest],
618
        trace_headers: Optional[Dict[str, str]] = None,
619
620
621
622
623
624
625
626
627
628
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams."""
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

629
630
631
        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
632

633
        sampling_params.update_from_generation_config(
634
            self.generation_config_fields, seq.eos_token_id)
635

636
        # Create the sequence group.
637
638
639
640
641
642
643
644
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
645

646
647
648
649
650
651
652
        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
    ) -> SequenceGroup:
        """Creates a SequenceGroup with PoolingParams.

        ``pooling_params`` is cloned (a defensive copy for the pooler). The
        group holds the single sequence ``seq``.
        """
        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             lora_request=lora_request,
                             pooling_params=pooling_params.clone())

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.
669
670

        Args:
Antoni Baum's avatar
Antoni Baum committed
671
            request_id: The ID(s) of the request to abort.
672
673
674
675
676
677
678
679
680
681
682

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
683
        """
684
685
        self.scheduler.abort_seq_group(request_id)

686
687
688
689
    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

690
691
692
693
    def get_decoding_config(self) -> DecodingConfig:
        """Gets the decoding configuration."""
        return self.decoding_config

694
    def get_num_unfinished_requests(self) -> int:
695
        """Gets the number of unfinished requests."""
696
697
        return self.scheduler.get_num_unfinished_seq_groups()

698
    def has_unfinished_requests(self) -> bool:
699
        """Returns True if there are unfinished requests."""
700
701
        return self.scheduler.has_unfinished_seqs()

702
703
704
705
706
707
708
709
710
711
712
713
    def _process_sequence_group_outputs(
        self,
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        """Attach an embedding result to ``seq_group`` and finish it.

        Only ``outputs[0].embeddings`` is used. Every sequence in the group
        is then marked ``SequenceStatus.FINISHED_STOPPED``.
        """
        first_output = outputs[0]
        seq_group.embeddings = first_output.embeddings
        for member in seq_group.get_seqs():
            member.status = SequenceStatus.FINISHED_STOPPED

    def _process_model_outputs(
        self,
        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
        scheduled_seq_groups: List[ScheduledSequenceGroup],
        ignored_seq_groups: List[SequenceGroup],
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Apply the model output to the sequences in the scheduled seq groups.

        Args:
            output: Model outputs for this iteration, laid out as
                [step][sequence group]. ``step`` passes an empty list when
                nothing was scheduled.
            scheduled_seq_groups: Sequence groups scheduled this iteration.
                They are zipped positionally with the regrouped outputs and
                with ``seq_group_metadata_list``.
            ignored_seq_groups: Sequence groups the scheduler ignored; they
                are only turned into request outputs, not updated.
            seq_group_metadata_list: Per-group scheduler metadata; its
                ``do_sample`` flag decides whether sampled outputs are
                applied to the group.

        Returns:
            RequestOutputs that can be returned to the client: one per
            scheduled group, followed by one per ignored group.
        """
        now = time.time()

        # Organize outputs by [sequence group][step] instead of
        # [step][sequence group].
        output_by_sequence_group = create_output_by_sequence_group(
            output, num_seq_groups=len(scheduled_seq_groups))

        # Update the scheduled sequence groups with the model outputs.
        for scheduled_seq_group, outputs, seq_group_meta in zip(
                scheduled_seq_groups, output_by_sequence_group,
                seq_group_metadata_list):
            seq_group = scheduled_seq_group.seq_group
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)

            if self.model_config.embedding_mode:
                # Embedding models bypass the token output processor.
                self._process_sequence_group_outputs(seq_group, outputs)
                continue

            self.output_processor.process_prompt_logprob(seq_group, outputs)
            # Only apply sampled tokens when the scheduler asked for
            # sampling on this group (presumably False for intermediate
            # chunked-prefill steps -- confirm in the scheduler).
            if seq_group_meta.do_sample:
                self.output_processor.process_outputs(seq_group, outputs)

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[Union[RequestOutput,
                                    EmbeddingRequestOutput]] = []
        for scheduled_seq_group in scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(seq_group)
            request_outputs.append(request_output)
        for seq_group in ignored_seq_groups:
            request_output = RequestOutputFactory.create(seq_group)
            request_outputs.append(request_output)
        return request_outputs

    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model
              (skipped when nothing was scheduled).
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it logs stats, emits trace spans for finished
              requests, and returns the newly generated results. When there
              are no results, it also stops the remote workers' execute-model
              loop.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>             print(request_output)
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if not scheduler_outputs.is_empty():
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
        else:
            # Nothing was scheduled, so there is no model output to apply.
            output = []

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        # Tracing
        self.do_tracing(scheduler_outputs)

        if not request_outputs:
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            self.model_executor.stop_remote_worker_execution_loop()

        return request_outputs

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` under ``logger_name`` in ``self.stat_loggers``.

        Raises:
            KeyError: If a logger is already registered under that name.
        """
        registry = self.stat_loggers
        if logger_name not in registry:
            registry[logger_name] = logger
            return
        raise KeyError(f"Logger with name {logger_name} already exists.")

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger registered under ``logger_name``.

        Raises:
            KeyError: If no logger is registered under that name.
        """
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

    def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Log current stats to every registered stat logger.

        ``step`` calls this once per iteration. It can also be called with no
        arguments to force a log when no requests are active.

        Args:
            scheduler_outputs: Optional; scheduler outputs of the batch just
                executed, used for batch-related metrics.
            model_output: Optional; model outputs of that batch, used for
                speculative-decoding metrics.
        """
        if not self.log_stats or not self.stat_loggers:
            return

        # Build the snapshot once and share it, so every logger reports the
        # same values for this call (including the same `now`). _get_stats
        # calls seq_group.get_last_latency(now) per scheduled group;
        # NOTE(review): if that call advances a per-group timestamp, as its
        # name suggests (confirm in SequenceGroup), recomputing it per logger
        # would hand later loggers near-zero latencies.
        stats = self._get_stats(scheduler_outputs, model_output)
        for logger in self.stat_loggers.values():
            logger.log(stats)

    def _get_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs],
            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
        """Get Stats to be logged by the registered stat loggers.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.

        Returns:
            A ``Stats`` snapshot containing:
            - system state (queue sizes, KV cache usage);
            - per-iteration token counts, latencies and preemptions;
            - request-level metadata for sequence groups that finished in
              this iteration.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = len(self.scheduler.running)
        num_swapped_sys = len(self.scheduler.swapped)
        num_waiting_sys = len(self.scheduler.waiting)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        # Same guard as the CPU cache below: an unset or zero-sized cache
        # reports 0% usage instead of dividing by zero.
        if num_total_gpu is not None and num_total_gpu > 0:
            num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
            )
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu is not None and num_total_cpu > 0:
            num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
            )
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # A token count, so accumulate it as an int; seeding it with the
            # float 0. would turn num_generation_tokens_iter into a float.
            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)

                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    if seq_group.sampling_params is not None:
                        best_of_requests.append(
                            seq_group.sampling_params.best_of)
                        n_requests.append(seq_group.sampling_params.n)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
        )

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter by delegating to the model executor.

        Returns:
            The boolean result of ``model_executor.add_lora``.
        """
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with ``lora_id`` via the model executor.

        Returns:
            The boolean result of ``model_executor.remove_lora``.
        """
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA IDs reported by the model executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with ``lora_id`` via the model executor.

        Returns:
            The boolean result of ``model_executor.pin_lora``.
        """
        return self.model_executor.pin_lora(lora_id)

    def check_health(self) -> None:
        """Run health checks on the tokenizer (if any) and the model executor.

        Calls ``tokenizer.check_health()`` when a tokenizer is set, then
        ``model_executor.check_health()``. Exceptions raised by either check
        propagate to the caller unchanged.
        """
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()

    def is_tracing_enabled(self) -> bool:
        """Return True when a tracer is configured (``self.tracer`` is set)."""
        tracer_missing = self.tracer is None
        return not tracer_missing

    def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
        """Create a trace span for every scheduled group that has finished.

        Does nothing when no tracer is configured.
        """
        if self.tracer is not None:
            # Lazily filter so each finished check is immediately followed by
            # its span creation, one group at a time.
            finished_groups = (
                item.seq_group
                for item in scheduler_outputs.scheduled_seq_groups
                if item.seq_group.is_finished())
            for finished in finished_groups:
                self.create_trace_span(finished)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            ttft = metrics.first_token_time - metrics.arrival_time
            e2e_time = metrics.finished_time - metrics.arrival_time
            # attribute names are based on
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
            seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
                                   seq_group.sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
                                   seq_group.sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                   seq_group.sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
                                   seq_group.sampling_params.best_of)
            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                   seq_group.sampling_params.n)
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                sum([
                    seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()
                ]))
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                                   metrics.time_in_queue)
            seq_span.set_attribute(
                SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)