llm_engine.py 87.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Antoni Baum's avatar
Antoni Baum committed
4
import time
5
from collections import Counter as collectionsCounter
6
from collections import deque
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
11
                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
12
from typing import Sequence as GenericSequence
13
from typing import Set, Type, Union, cast
14

15
import torch
16
from typing_extensions import TypeVar
17

18
import vllm.envs as envs
19
20
21
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
22
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import EngineArgs
24
from vllm.engine.metrics_types import StatLoggerBase, Stats
25
26
27
28
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
29
30
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
31
from vllm.executor.executor_base import ExecutorBase
32
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
33
from vllm.inputs.parse import split_enc_dec_inputs
34
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.logger import init_logger
36
from vllm.logits_process import get_bad_words_logits_processors
37
from vllm.lora.request import LoRARequest
38
from vllm.model_executor.layers.sampler import SamplerOutput
39
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
40
from vllm.multimodal.processing import EncDecMultiModalProcessor
41
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
42
43
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
44
from vllm.sampling_params import RequestOutputKind, SamplingParams
45
46
47
48
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceStatus)
49
50
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
51
from vllm.transformers_utils.detokenizer import Detokenizer
52
from vllm.transformers_utils.tokenizer import AnyTokenizer
53
from vllm.transformers_utils.tokenizer_group import (
54
    TokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
55
56
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
57
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
58
from vllm.version import __version__ as VLLM_VERSION
59
from vllm.worker.model_runner_base import InputProcessingError
60
61

logger = init_logger(__name__)

# Interval in seconds handed to the default stat loggers as `local_interval`.
_LOCAL_LOGGING_INTERVAL_SEC = 5

# Request-output type accepted by LLMEngine.validate_output(s):
# either a generation output or a pooling output.
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
# Generic result type variable; falls back to Any when not parameterized.
_R = TypeVar("_R", default=Any)
66
67


68
69
70
71
72
@dataclass
class SchedulerOutputState:
    """Scheduler outputs cached for a single virtual engine.

    Used for multi-step scheduling. Every field defaults to an "empty"
    value, so a freshly constructed instance caches nothing.
    """

    seq_group_metadata_list: Union[List[SequenceGroupMetadata], None] = None
    scheduler_outputs: Union[SchedulerOutputs, None] = None
    allow_async_output_proc: bool = False
    last_output: Union[SamplerOutput, None] = None
75
76


77
78
79
80
81
82
class OutputData(NamedTuple):
    """One entry of ``SchedulerContext.output_queue``.

    Bundles the sampler outputs of a step with the scheduling information
    they were produced under.
    """

    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool

    # Whether this output comes from the first step of a multi-step run.
    # Always True when multi-step is disabled. The value is not meaningful
    # when `outputs` holds outputs from multiple steps.
    is_first_step_output: Optional[bool]

    # NOTE(review): presumably indices of entries to skip during output
    # processing; SchedulerContext.append_output always starts it empty.
    skip: List[int]


92
class SchedulerContext:
    """Mutable output-processing state kept for one virtual engine.

    The engine builds one instance per pipeline-parallel virtual engine and
    binds it into that engine's async output-processing callback.
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        # FIFO of OutputData records enqueued by `append_output`.
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []

        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Package one step's outputs as an OutputData and enqueue it.

        The record's `skip` list always starts out empty.
        """
        record = OutputData(
            outputs,
            seq_group_metadata_list,
            scheduler_outputs,
            is_async,
            is_last_step,
            is_first_step_output,
            [],
        )
        self.output_queue.append(record)
117
118


119
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
120
    """An LLM engine that receives requests and generates texts.
121

Woosuk Kwon's avatar
Woosuk Kwon committed
122
    This is the main class for the vLLM engine. It receives requests
123
124
125
126
127
128
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

129
130
131
    The [`LLM`][vllm.LLM] class wraps this class for offline batched inference
    and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine]
    class wraps this class for online serving.
132

133
    The config arguments are derived from [`EngineArgs`][vllm.EngineArgs].
134
135

    Args:
136
        vllm_config: The configuration for initializing and running vLLM.
137
138
        executor_class: The model executor class for managing distributed
            execution.
139
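        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
        stat_loggers: Optional custom stat loggers keyed by name. If omitted
            while `log_stats` is true, a logging and a Prometheus stat
            logger are created.
        mm_registry: The multi-modal registry passed to the input
            preprocessor.
        use_cached_outputs: Stored on the engine and reported in the
            startup log.
    """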
142

143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Temporarily enable runtime type checks of request outputs.

        While the context is active, `validate_output` and `validate_outputs`
        perform `isinstance` checks against the requested output type.

        On exit the flag is restored to the value it had on entry. This is
        done in a `finally` clause, so an exception raised inside the `with`
        body cannot leave validation switched on, and nested uses do not
        turn off an outer, still-active context.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return `output` typed as `output_type`, optionally checking it.

        The `isinstance` check only runs while `DO_VALIDATE_OUTPUT` is set
        (or under static type checking); otherwise the value is passed
        through unchanged and only cast for the type checker.

        Raises:
            TypeError: If validation is active and `output` is not an
                instance of `output_type`.
        """
        should_check = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if should_check and not isinstance(output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return cast(_O, output)
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return `outputs` typed as a list of `output_type`.

        When validation is inactive, the caller's sequence object itself is
        returned unchanged. When it is active, each element is checked and
        a new list of the checked elements is returned.

        Raises:
            TypeError: If validation is active and any element is not an
                instance of `output_type`.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # No checking requested: pass the original object through.
            return cast(List[_O], outputs)

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)
        return checked

192
    tokenizer: Optional[TokenizerGroup]

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build a V0 engine from an already-resolved `VllmConfig`.

        Construction order matters:
        1. The tokenizer and input preprocessor are created.
        2. The model executor is instantiated.
        3. For non-pooling runners, the KV cache is profiled and allocated
           (`_initialize_kv_caches`).
        4. The per-virtual-engine contexts, schedulers, stat loggers,
           tracer and output processor are set up.

        Args:
            vllm_config: The configuration for initializing and running vLLM.
            executor_class: Executor class, instantiated with `vllm_config`.
            log_stats: Whether to set up stat loggers.
            usage_context: Entry point reported with usage statistics.
            stat_loggers: Optional custom stat loggers keyed by name. If
                omitted while `log_stats` is true, a logging and a Prometheus
                logger are created.
            mm_registry: Multi-modal registry handed to the input
                preprocessor.
            use_cached_outputs: Stored on the engine and reported in the
                startup log.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is set, since this is the V0
                engine.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = (vllm_config.decoding_config
                                or DecodingConfig())
        self.observability_config = (vllm_config.observability_config
                                     or ObservabilityConfig())

        # No trailing separator after the last field.
        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None
        else:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.model_executor = executor_class(vllm_config=vllm_config)

        # Pooling runners skip KV cache profiling/allocation.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        # One cached scheduler output and one scheduler context per
        # pipeline-parallel virtual engine.
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            # weak_bind keeps the callbacks from holding a strong reference
            # to the engine.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(self.scheduler_config.max_model_len,
                                         get_tokenizer_for_seq),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

398
399
400
401
402
403
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers determine the number of blocks in both the GPU cache and
        the swap CPU cache. If `cache_config.num_gpu_blocks_override` is set,
        it replaces the profiled GPU block count. The final counts are
        written back to `cache_config`, the executor is told to allocate the
        caches, and the total time taken is logged.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments (NTP, DST, manual changes).
        start = time.perf_counter()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
        if num_gpu_blocks_override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
423

424
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Resolve the executor class selected by `engine_config`.

        `parallel_config.distributed_executor_backend` is either an
        `ExecutorBase` subclass, used as-is, or one of the strings `"ray"`,
        `"mp"`, `"uni"` or `"external_launcher"`. Executor modules are
        imported lazily, so only the selected backend's module is loaded.

        Raises:
            TypeError: If a class is given that does not subclass
                `ExecutorBase`.
            ValueError: If the backend value is not recognized.
        """
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        backend = engine_config.parallel_config.distributed_executor_backend

        if isinstance(backend, type):
            if issubclass(backend, ExecutorBase):
                return backend
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorBase. Got {backend}.")

        if backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            return RayDistributedExecutor

        if backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            return MultiprocessingDistributedExecutor

        if backend == "uni":
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            return UniProcExecutor

        if backend == "external_launcher":
            # executor with external launcher
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            return ExecutorWithExternalLauncher

        raise ValueError("unrecognized distributed_executor_backend: "
                         f"{backend}")

462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Construct an engine from a resolved config.

        The executor class is chosen by `_get_executor_cls`. Stat logging is
        enabled unless `disable_log_stats` is true.
        """
        executor_class = cls._get_executor_cls(vllm_config)
        log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

478
479
480
481
482
483
484
485
486
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The engine config is built from `engine_args`. When
        `envs.VLLM_USE_V1` is set, construction is delegated to the V1
        `LLMEngine`; otherwise this class is used.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            target_cls = V1LLMEngine
        else:
            target_cls = cls

        return target_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
        )
500

501
502
503
504
505
    def __reduce__(self):
        """Refuse to pickle the engine.

        Raising here ensures the engine is not captured in the closure used
        to initialize Ray worker actors.
        """
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

506
507
508
509
510
511
    def __del__(self):
        """Shut down the model executor when the engine is garbage collected."""
        # __init__ can fail before `model_executor` is assigned, so look it
        # up defensively instead of accessing the attribute directly.
        executor = getattr(self, "model_executor", None)
        if executor:
            executor.shutdown()

512
513
    def get_tokenizer_group(self) -> TokenizerGroup:
        """Return the engine's tokenizer group.

        Raises:
            ValueError: If the engine has no tokenizer, i.e. it was built
                with `skip_tokenizer_init`.
        """
        group = self.tokenizer
        if group is not None:
            return group
        raise ValueError("Unable to get tokenizer because "
                         "skip_tokenizer_init is True")
518

519
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Look up the tokenizer for `lora_request` via the tokenizer group.

        Raises:
            ValueError: Propagated from `get_tokenizer_group` when the
                engine has no tokenizer.
        """
        group = self.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
524

525
    def _init_tokenizer(self) -> TokenizerGroup:
        """Build a tokenizer group from the model, scheduler and LoRA configs."""
        model_cfg = self.model_config
        scheduler_cfg = self.scheduler_config
        lora_cfg = self.lora_config
        return init_tokenizer_from_configs(model_config=model_cfg,
                                           scheduler_config=scheduler_cfg,
                                           lora_config=lora_cfg)
530

531
532
    def _verify_args(self) -> None:
        """Cross-check the engine's configs against each other.

        The model and cache configs are verified against the parallel
        config. If LoRA is configured, it is verified against the model and
        scheduler configs.
        """
        parallel_cfg = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel_cfg)
        self.cache_config.verify_with_parallel_config(parallel_cfg)

        lora_cfg = self.lora_config
        if not lora_cfg:
            return
        lora_cfg.verify_with_model_config(self.model_config)
        lora_cfg.verify_with_scheduler_config(self.scheduler_config)
538

539
540
541
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add an already-preprocessed request to the engine.

        For sampling requests with `n > 1`, the work is delegated to
        `ParallelSampleSequenceGroup.add_request`, which is given the engine
        itself, and `None` is returned. Otherwise:
        - the inputs are validated;
        - a decoder `Sequence` (plus an encoder `Sequence` for
          encoder-decoder inputs) is built;
        - these are wrapped in a `SequenceGroup`, which is handed to the
          scheduler with the fewest unfinished sequence groups and returned.

        Raises:
            ValueError: If `params` is neither `SamplingParams` nor
                `PoolingParams`.
        """
        is_sampling = isinstance(params, SamplingParams)

        if is_sampling and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)

        # Build the decoder (and optional encoder) sequences; both share the
        # same sequence id.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request)
        if encoder_inputs is None:
            encoder_seq = None
        else:
            encoder_seq = Sequence(seq_id, encoder_inputs, block_size,
                                   eos_token_id, lora_request)

        if is_sampling:
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Route to the least-loaded scheduler; min() returns the first one on
        # ties, i.e. the lowest virtual-engine index.
        target_scheduler = min(
            self.scheduler,
            key=lambda sched: sched.get_num_unfinished_seq_groups())
        target_scheduler.add_seq_group(seq_group)

        return seq_group

613
614
    def stop_remote_worker_execution_loop(self) -> None:
        """Forward the stop request to the model executor."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
615

616
617
618
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See
                [PromptType][vllm.inputs.PromptType]
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                [SamplingParams][vllm.SamplingParams] for text generation.
                [PoolingParams][vllm.PoolingParams] for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (`time.time()`) is used.
            lora_request: The LoRA request to add.
            tokenization_kwargs: Extra keyword arguments forwarded to the
                input preprocessor.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Raises:
            TypeError: If `request_id` is not a string.
            ValueError: In any of these cases:
                - a LoRA request is given but LoRA is not enabled;
                - a non-zero priority is given without priority scheduling;
                - logits processors are combined with multi-step decoding.

        Details:
            - Set arrival_time to the current time if it is None.
            - For prompts that carry `prompt_embeds` but no
              `prompt_token_ids`, use a copy of the prompt with placeholder
              token ids (one per embedding row); the caller's prompt is not
              modified.
            - Preprocess the prompt with the input preprocessor.
            - Create the [Sequence][vllm.Sequence] /
              [SequenceGroup][vllm.SequenceGroup] objects and add the group
              to the scheduler with the fewest unfinished sequence groups.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if not isinstance(request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request_id)}")

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if (isinstance(params, SamplingParams) and params.logits_processors
                and self.scheduler_config.num_scheduler_steps > 1):
            raise ValueError(
                "Logits processors are not supported in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # Embeds-only prompts get placeholder token ids, one per row of
            # the embeddings tensor. Build a new dict instead of writing into
            # the caller's prompt, so a reused prompt object is not silently
            # altered by submitting it.
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt = cast(PromptType, {
                **prompt, "prompt_token_ids": [0] * seq_len
            })

        processed_inputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )
715
716
717
718
719
720

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
721
722
        arrival_time: float,
        lora_request: Optional[LoRARequest],
723
        trace_headers: Optional[Mapping[str, str]] = None,
724
        encoder_seq: Optional[Sequence] = None,
725
        priority: int = 0,
726
727
728
729
730
731
732
733
734
735
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams."""
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

736
737
738
        sampling_params = self._build_logits_processors(
            sampling_params, lora_request)

739
740
741
        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
742

743
        sampling_params.update_from_generation_config(
744
            self.generation_config_fields, seq.eos_token_id)
745

746
        # Create the sequence group.
747
748
749
750
        draft_size = 1
        if self.vllm_config.speculative_config is not None:
            draft_size = \
                self.vllm_config.speculative_config.num_speculative_tokens + 1
751
752
753
754
755
756
757
758
759
        seq_group = SequenceGroup(request_id=request_id,
                                  seqs=[seq],
                                  arrival_time=arrival_time,
                                  sampling_params=sampling_params,
                                  lora_request=lora_request,
                                  trace_headers=trace_headers,
                                  encoder_seq=encoder_seq,
                                  priority=priority,
                                  draft_size=draft_size)
760

761
762
763
764
765
766
767
        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Wrap ``seq`` in a new SequenceGroup driven by ``pooling_params``.

        The params are cloned before being handed to the group, so the pooler
        that reads them works on a private copy rather than the caller's
        object.
        """
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params.clone(),
            encoder_seq=encoder_seq,
            priority=priority,
        )
785

Antoni Baum's avatar
Antoni Baum committed
786
787
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests by ID on every scheduler.

        Args:
            request_id: A single request ID or an iterable of request IDs.

        Details:
            - Forwards ``request_id`` unchanged to each scheduler's
              [vllm.core.scheduler.Scheduler.abort_seq_group][], together
              with the engine's ``seq_id_to_seq_group`` mapping.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        seq_map = self.seq_id_to_seq_group
        for sched in self.scheduler:
            sched.abort_seq_group(request_id, seq_id_to_seq_group=seq_map)
804

805
806
807
808
    def get_vllm_config(self) -> VllmConfig:
        """Return the top-level ``VllmConfig`` stored on this engine."""
        return self.vllm_config

809
810
811
812
    def get_model_config(self) -> ModelConfig:
        """Return the ``ModelConfig`` stored on this engine."""
        return self.model_config

813
814
815
816
    def get_parallel_config(self) -> ParallelConfig:
        """Return the ``ParallelConfig`` stored on this engine."""
        return self.parallel_config

817
818
819
820
    def get_decoding_config(self) -> DecodingConfig:
        """Return the ``DecodingConfig`` stored on this engine."""
        return self.decoding_config

821
822
823
824
825
826
827
828
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the ``SchedulerConfig`` stored on this engine."""
        return self.scheduler_config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Gets the LoRA configuration.

        Returns:
            The engine's ``LoRAConfig``. This may be ``None``: ``add_request``
            treats a falsy ``lora_config`` as "LoRA is not enabled", so the
            return type is ``Optional``.
        """
        return self.lora_config

829
    def get_num_unfinished_requests(self) -> int:
        """Return the unfinished sequence-group count summed over schedulers."""
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total
833

834
    def has_unfinished_requests(self) -> bool:
        """Return True as soon as any scheduler reports unfinished sequences.

        Schedulers after the first positive answer are not queried.
        """
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Return whether the scheduler at index ``virtual_engine`` still has
        unfinished sequences.
        """
        target = self.scheduler[virtual_engine]
        return target.has_unfinished_seqs()
845

846
847
    def reset_mm_cache(self) -> bool:
        """Reset the multi-modal processor cache.

        Delegates to the input preprocessor's multi-modal registry, passing
        this engine's model config, and returns the registry's result.
        """
        registry = self.input_preprocessor.mm_registry
        return registry.reset_processor_cache(self.model_config)
850

851
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset prefix cache for all devices.

        Every scheduler is asked to reset its prefix cache, even if an earlier
        one reports failure. A failure on one scheduler therefore does not
        prevent the others from being reset.

        Args:
            device: Forwarded unchanged to each scheduler's
                ``reset_prefix_cache``.

        Returns:
            True only if every scheduler reported a successful reset (True for
            an empty scheduler list).
        """
        # Materialize the results first: feeding a generator to all() (or
        # chaining with `and`) would stop calling schedulers after the first
        # failure.
        results = [
            scheduler.reset_prefix_cache(device)
            for scheduler in self.scheduler
        ]
        return all(results)

859
    @staticmethod
860
861
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
862
        outputs: List[PoolingSequenceGroupOutput],
863
    ) -> None:
864
        seq_group.pooled_data = outputs[0].data
865
866
867
868
869
870

        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_STOPPED

        return

871
872
873
874
875
876
877
878
    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """
        This function updates num_computed_tokens for prompt sequences
        when Multi-Step is enabled.

879
        seq_group: SequenceGroup to update the num_computed_tokens for.
880
        seq_group_meta: Metadata of the given SequenceGroup.
881
        is_first_step_output: Optional[bool] -
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
            When available, is_first_step_output indicates if the appended
            output token is the output of the first-step in multi-step.
            A value of None indicates that outputs from all steps in
            in multi-step are submitted in a single burst.
        """

        assert self.scheduler_config.is_multi_step

        if not seq_group_meta.is_prompt:
            # num_computed_token updates for multi-step decodes happen after
            # the tokens are appended to the sequence.
            return

        do_update: bool = False
        if self.scheduler_config.chunked_prefill_enabled:
            # In multi-step + chunked-prefill case, the prompt sequences
            # that are scheduled are fully processed in the first step.
            do_update = is_first_step_output is None or is_first_step_output
        else:
            # Normal multi-step decoding case. In this case prompt-sequences
            # are actually single-stepped. Always update in this case.
            assert seq_group.state.num_steps == 1
            do_update = True

        if do_update:
            seq_group.update_num_computed_tokens(
                seq_group_meta.token_chunk_size)

910
911
912
913
914
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and return responses.
915

916
917
        ctx: The virtual engine context to work on
        request_id: If provided, then only this request is going to be processed
918
        """
919

920
        now = time.time()
921

922
        if len(ctx.output_queue) == 0:
923
924
            return None

925
        # Get pending async postprocessor
926
927
928
929
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
930
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
931
932
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
933
934
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()
935
936
937
938
939

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

940
        has_multiple_outputs: bool = len(outputs) > 1
941
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
942
943
944
945
946
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
            if self.scheduler_config.is_multi_step:
                outputs_by_sequence_group = create_output_by_sequence_group(
                    outputs, len(seq_group_metadata_list))
            elif self.speculative_config:
                # Decodes are multi-steps while prefills are not, outputting at
                # most 1 token. Separate them so that we can trigger chunk
                # processing without having to pad or copy over prompts K times
                # to match decodes structure (costly with prompt_logprobs).
                num_prefills = sum(sg.is_prompt
                                   for sg in seq_group_metadata_list)
                prefills, decodes = outputs[:num_prefills], outputs[
                    num_prefills:]
                outputs_by_sequence_group = create_output_by_sequence_group(
                    decodes,
                    num_seq_groups=len(seq_group_metadata_list) - num_prefills)
                outputs_by_sequence_group = [p.outputs for p in prefills
                                             ] + outputs_by_sequence_group
964
965
966
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
967
968
969
        else:
            outputs_by_sequence_group = outputs

970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

987
        finished_before: List[int] = []
988
        finished_now: List[int] = []
989
990
991
992
993
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
994
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
995

996
            seq_group: SequenceGroup = scheduled_seq_group.seq_group
997
998
999
1000
1001

            if seq_group.is_finished():
                finished_before.append(i)
                continue

1002
            output: List[SequenceGroupOutput]
1003
            if has_multiple_outputs:
1004
1005
1006
1007
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

1008
1009
1010
1011
1012
1013
1014
            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
1015
                        seq_group_meta.token_chunk_size or 0)
1016
1017
1018

            if outputs:
                for o in outputs:
1019
1020
1021
1022
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
1023
                                o.model_forward_time or 0)
1024
1025
1026
1027
1028
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
1029
                                o.model_execute_time or 0)
1030
1031
1032
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)
1033

1034
            if self.model_config.runner_type == "pooling":
1035
                self._process_sequence_group_outputs(seq_group, output)
1036
1037
1038
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
1039
                    self.output_processor.process_outputs(
1040
                        seq_group, output, is_async)
1041

1042
1043
            if seq_group.is_finished():
                finished_now.append(i)
1044

1045
1046
1047
        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
1048

1049
1050
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
1051
1052
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
1053
            request_output = RequestOutputFactory.create(
1054
1055
1056
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
1057
1058
            if request_output:
                ctx.request_outputs.append(request_output)
1059

1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

1072
1073
1074
1075
1076
        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

1077
1078
        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
1079
1080
1081
1082
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
1083
                ctx.request_outputs.clear()
1084
1085
1086
            return

        # Create the outputs
1087
1088
        for i in indices:
            if i in skip or i in finished_before or i in finished_now:
1089
1090
                continue  # Avoids double processing

1091
1092
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

1093
            seq_group = scheduled_seq_group.seq_group
1094
            seq_group.maybe_set_first_token_time(now)
1095
1096
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
1097
            request_output = RequestOutputFactory.create(
1098
1099
1100
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
1101
            if request_output:
1102
                ctx.request_outputs.append(request_output)
1103

1104
1105
1106
1107
1108
1109
1110
1111
        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

1112
        for seq_group in scheduler_outputs.ignored_seq_groups:
1113
1114
1115
1116
1117
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

1118
            request_output = RequestOutputFactory.create(
1119
1120
1121
1122
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
1123
1124
            if request_output:
                ctx.request_outputs.append(request_output)
1125

1126
1127
1128
1129
        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
1130
            ctx.request_outputs.clear()
1131

1132
1133
1134
1135
        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
1136
            # Log stats.
1137
1138
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)
1139
1140

            # Tracing
1141
            self.do_tracing(scheduler_outputs, finished_before)
1142
1143
1144
1145

        return None

    def _advance_to_next_step(
1146
            self, output: SamplerOutput,
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done inside output processor, but it is
        required if the worker is to perform async forward pass to next step.
        """
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

1160
1161
1162
1163
1164
            if self.scheduler_config.is_multi_step:
                # Updates happen only if the sequence is prefill
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, seq_group_metadata,
                    seq_group.state.num_steps == 1)
1165
            else:
1166
1167
1168
1169
                token_chunk_size = (seq_group_metadata.token_chunk_size
                                    if seq_group_metadata.token_chunk_size
                                    is not None else 0)
                seq_group.update_num_computed_tokens(token_chunk_size)
1170

1171
1172
1173
            if seq_group_metadata.do_sample:
                assert len(sequence_group_outputs.samples) == 1, (
                    "Async output processor expects a single sample"
1174
                    " (i.e sampling_params.n == 1)")
1175
1176
1177
1178
                sample = sequence_group_outputs.samples[0]

                assert len(seq_group.seqs) == 1
                seq = seq_group.seqs[0]
1179
1180
1181
1182

                if self.scheduler_config.is_multi_step:
                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
                    ) == 0
1183
1184
                    seq.append_token_id(sample.output_token, sample.logprobs,
                                        sample.output_embed)
1185
1186
1187
                    if not is_prefill_append:
                        seq_group.update_num_computed_tokens(1)
                else:
1188
1189
                    seq.append_token_id(sample.output_token, sample.logprobs,
                                        sample.output_embed)
1190

1191
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
Antoni Baum's avatar
Antoni Baum committed
1192
1193
        """Performs one decoding iteration and returns newly generated results.

1194
1195
1196
1197
        <figure markdown="span">
        ![Overview of the step function](https://i.imgur.com/sv2HssD.png)
        <figcaption>Overview of the step function</figcaption>
        </figure>
1198
1199

        Details:
1200
1201
        - Step 1: Schedules the sequences to be executed in the next
            iteration and the token blocks to be swapped in/out/copy.
1202

1203
1204
1205
1206
            - Depending on the scheduling policy,
                sequences may be `preempted/reordered`.
            - A Sequence Group (SG) refer to a group of sequences
                that are generated from the same prompt.
1207

1208
1209
        - Step 2: Calls the distributed executor to execute the model.
        - Step 3: Processes the model output. This mainly includes:
1210

1211
1212
1213
1214
            - Decodes the relevant outputs.
            - Updates the scheduled sequence groups with model outputs
                based on its `sampling parameters` (`use_beam_search` or not).
            - Frees the finished sequence groups.
1215

1216
        - Finally, it creates and returns the newly generated results.
1217
1218

        Example:
1219
1220
1221
1222
1223
1224
1225
        ```
        # Please see the example/ folder for more detailed examples.

        # initialize engine and request arguments
        engine = LLMEngine.from_engine_args(engine_args)
        example_inputs = [(0, "What is LLM?",
        SamplingParams(temperature=0.0))]
1226

1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
        # Start the engine with an event loop
        while True:
            if example_inputs:
                req_id, prompt, sampling_params = example_inputs.pop(0)
                engine.add_request(str(req_id),prompt,sampling_params)

            # continue the request processing
            request_outputs = engine.step()
            for request_output in request_outputs:
                if request_output.finished:
                    # return or show the request output

            if not (engine.has_unfinished_requests() or example_inputs):
                break
        ```
Antoni Baum's avatar
Antoni Baum committed
1242
        """
1243
1244
1245
1246
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")
1247

1248
        # For llm_engine, there is no pipeline parallel support, so the engine
1249
        # used is always 0.
1250
1251
        virtual_engine = 0

1252
1253
        # These are cached outputs from previous iterations. None if on first
        # iteration
1254
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
1255
1256
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
1257
        allow_async_output_proc = cached_outputs.allow_async_output_proc
1258

1259
1260
        ctx = self.scheduler_contexts[virtual_engine]

1261
1262
1263
        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

1264
1265
1266
        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
1267
1268
1269
1270
1271
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
1272
            # Schedule iteration
1273
            (seq_group_metadata_list, scheduler_outputs,
1274
1275
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
1276

1277
1278
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
1279

1280
1281
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
1282
1283
1284
1285
1286
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]
1287

1288
1289
            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
1290
                self._process_model_outputs(ctx=ctx)
1291

1292
1293
1294
1295
1296
            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
1297
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
1298
                    allow_async_output_proc)
1299
1300
        else:
            finished_requests_ids = list()
1301
1302
1303

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
Antoni Baum's avatar
Antoni Baum committed
1304

1305
        if not scheduler_outputs.is_empty():
1306
1307
1308
1309
1310
1311

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
1312
                self._get_last_sampled_token_ids(virtual_engine)
1313

1314
            execute_model_req = ExecuteModelRequest(
1315
1316
1317
1318
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
1319
1320
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
1321
1322
1323
1324
1325
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

1326
            if allow_async_output_proc:
1327
1328
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]
1329

1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise
1347

1348
            # We need to do this here so that last step's sampled_token_ids can
1349
1350
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
1351
                self._update_cached_scheduler_output(virtual_engine, outputs)
1352
        else:
1353
1354
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
1355
1356
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
1357
            # No outputs in this case
1358
            outputs = []
Antoni Baum's avatar
Antoni Baum committed
1359

1360
1361
1362
1363
1364
1365
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
1366
            # clear the cache if we have finished all the steps.
1367
1368
1369
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()

1370
1371
1372
1373
1374
1375
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

1376
            # Add results to the output_queue
1377
1378
1379
1380
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
1381
1382
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)
1383
1384
1385

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
1386
                    "Async postprocessor expects only a single output set")
1387

1388
                self._advance_to_next_step(
1389
                    outputs[0], seq_group_metadata_list,
1390
                    scheduler_outputs.scheduled_seq_groups)
1391

1392
            # Check if need to run the usual non-async path
1393
            if not allow_async_output_proc:
1394
                self._process_model_outputs(ctx=ctx)
1395

1396
                # Log stats.
1397
                self.do_log_stats(scheduler_outputs, outputs)
1398

1399
1400
1401
                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
1402
            # Multi-step case
1403
            return ctx.request_outputs
1404

1405
        if not self.has_unfinished_requests():
1406
1407
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
1408
                self._process_model_outputs(ctx=ctx)
1409
            assert len(ctx.output_queue) == 0
1410

1411
1412
1413
1414
1415
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
1416
            logger.debug("Stopping remote worker execution loop.")
1417
1418
            self.model_executor.stop_remote_worker_execution_loop()

1419
        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1420

1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Abort one request and cache the remainder of the current schedule.

        The request is aborted and its first matching entry is removed, in
        place, from both ``seq_group_metadata_list`` and
        ``scheduler_outputs.scheduled_seq_groups``. If any sequence groups
        remain, ``_skip_scheduling_next_step`` is set and the trimmed
        schedule is stored via ``_cache_scheduler_outputs_for_multi_step``,
        so the next step can run the remaining requests without re-running
        the scheduler.
        """
        self.abort_request(request_id)

        # Locate and drop this request's metadata entry (in place, so the
        # list object that gets cached below reflects the removal).
        metadata_idx = next(
            (pos for pos, md in enumerate(seq_group_metadata_list)
             if md.request_id == request_id), None)
        if metadata_idx is not None:
            del seq_group_metadata_list[metadata_idx]

        # Same for the scheduled sequence group.
        scheduled = scheduler_outputs.scheduled_seq_groups
        group_idx = next((pos for pos, sg in enumerate(scheduled)
                          if sg.seq_group.request_id == request_id), None)
        if group_idx is not None:
            del scheduled[group_idx]

        if len(seq_group_metadata_list) == 0:
            return

        # Other work is still scheduled: flag the engine to reuse it.
        self._skip_scheduling_next_step = True
        # Reuse multi-step caching logic
        self._cache_scheduler_outputs_for_multi_step(
            virtual_engine=virtual_engine,
            scheduler_outputs=scheduler_outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            allow_async_output_proc=allow_async_output_proc)

1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return True if the current multi-step schedule has steps left.

        Always returns False when multi-step scheduling is disabled or when
        ``seq_group_metadata_list`` is empty/None.

        Raises:
            AssertionError: if the sequence groups do not all report the same
                ``state.remaining_steps``.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort
        # of dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        # Generator expression so the check stops at the first mismatch.
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store a schedule on ``cached_scheduler_outputs[virtual_engine]``.

        Overwrites the cached metadata list, scheduler outputs and async
        flag, and resets ``last_output`` to None.
        """
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        # Cleared here; _update_cached_scheduler_output repopulates it after
        # the model has run.
        co.last_output = None
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Cache the final sampler output for pipeline-parallel runs.

        Acts only when ``pipeline_parallel_size > 1`` and ``output`` is
        non-empty with a non-None first element. The last element is stored
        as ``cached_scheduler_outputs[virtual_engine].last_output`` so that
        ``_get_last_sampled_token_ids`` can return its CPU token ids later.
        """
        uses_pipeline = self.parallel_config.pipeline_parallel_size > 1
        if not uses_pipeline or len(output) == 0 or output[0] is None:
            return

        newest = output[-1]
        # Only the CPU copy of the sampled token ids is expected here.
        assert newest is not None
        assert newest.sampled_token_ids_cpu is not None
        assert newest.sampled_token_ids is None
        assert newest.sampled_token_probs is None
        self.cached_scheduler_outputs[virtual_engine].last_output = newest

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1509
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` in ``self.stat_loggers`` under ``logger_name``.

        Raises:
            RuntimeError: if stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: if a logger with this name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        self.stat_loggers[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            RuntimeError: if stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: if no logger with this name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

1527
1528
1529
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Collect stats and pass them to every registered stat logger.

        Does nothing when stat logging is disabled. All arguments are
        forwarded to ``_get_stats``. Calling with no arguments logs only
        system-level state (e.g. to force a log when no requests are
        active); per-step callers pass the step's scheduler outputs.
        """
        if self.log_stats:
            stats = self._get_stats(scheduler_outputs, model_output,
                                    finished_before, skip)
            for logger in self.stat_loggers.values():
                logger.log(stats)
1538

1539
1540
1541
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch. When None, only system-level stats are
                filled in and iteration/request stats stay empty/zero.
            model_output: Optional. Not read by this method; kept for
                interface compatibility with existing callers.
            finished_before: Optional, indices into
                ``scheduler_outputs.scheduled_seq_groups`` of groups that
                were finished before. These groups will be ignored.
            skip: Optional, indices into
                ``scheduler_outputs.scheduled_seq_groups`` of groups that
                were preempted. These groups will be ignored.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Exchange the usage and cache hit stats between gpu and cpu when
        # running on cpu because the cpu_worker.py intentionally reports the
        # number of cpu blocks as gpu blocks in favor of cache management.
        if self.device_config.device_type == "cpu":
            gpu_cache_usage_sys, cpu_cache_usage_sys = (
                cpu_cache_usage_sys,
                gpu_cache_usage_sys,
            )
            gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = (
                cpu_prefix_cache_hit_rate,
                gpu_prefix_cache_hit_rate,
            )

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests: unique adapter names, in first-seen order (only the
        # names are reported, so counting occurrences is unnecessary).
        running_lora_adapters = list(
            dict.fromkeys(
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request))
        waiting_lora_adapters = list(
            dict.fromkeys(
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            # Sets give O(1) membership tests inside the loop below.
            finished_before_idx = set(finished_before or ())
            skip_idx = set(skip or ())

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_before_idx:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skip_idx:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_token_latency()
                    time_per_output_tokens_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens += \
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens += \
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            max_lora=max_lora_stat,
            waiting_lora_adapters=waiting_lora_adapters,
            running_lora_adapters=running_lora_adapters)
1799

1800
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model executor's ``add_lora``.

        Returns the executor's boolean result.
        """
        return self.model_executor.add_lora(lora_request)
1802
1803

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model executor to remove the LoRA adapter ``lora_id``.

        Returns the executor's boolean result.
        """
        return self.model_executor.remove_lora(lora_id)
1805

1806
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model executor."""
        return self.model_executor.list_loras()
1808

1809
1810
1811
    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model executor to pin the LoRA adapter ``lora_id``.

        Returns whatever boolean the executor reports.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1812
1813
1814
1815
1816
1817
    def start_profile(self) -> None:
        """Start profiling through the model executor."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling through the model executor."""
        executor = self.model_executor
        executor.stop_profile()

1818
1819
1820
1821
1822
    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at ``level``.

        Sleep mode must be enabled in the model config; otherwise an
        AssertionError is raised (when assertions are active).
        """
        sleep_enabled = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_enabled, "Sleep mode is not enabled in the model config"
        self.model_executor.sleep(level=level)

1823
    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the model executor from sleep mode.

        Args:
            tags: Optional list forwarded unchanged to the executor's
                ``wake_up``; its interpretation is up to the executor.

        Raises:
            AssertionError: if sleep mode is not enabled in the model config
                (when assertions are active).
        """
        assert self.vllm_config.model_config.enable_sleep_mode, (
            "Sleep mode is not enabled in the model config")
        self.model_executor.wake_up(tags)
1827

1828
1829
1830
    def is_sleeping(self) -> bool:
        """Report the model executor's ``is_sleeping`` flag."""
        executor = self.model_executor
        return executor.is_sleeping

1831
    def check_health(self) -> None:
        """Delegate a health check to the model executor.

        Any exception raised by the executor propagates to the caller.
        """
        self.model_executor.check_health()
1833
1834
1835
1836

    def is_tracing_enabled(self) -> bool:
        """Return True when a tracer has been configured on the engine."""
        tracer = self.tracer
        return tracer is not None

1837
1838
1839
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for each finished scheduled sequence group.

        No-op when no tracer is configured.

        Args:
            scheduler_outputs: Scheduler outputs of the current step.
            finished_before: Optional indices into
                ``scheduler_outputs.scheduled_seq_groups`` to skip, so groups
                already traced (async output processing) are not traced
                twice.
        """
        if self.tracer is None:
            return

        # Set for O(1) membership tests inside the loop.
        already_traced = set(finished_before or ())
        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if idx in already_traced:
                continue

            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record an ``llm_request`` span describing a sequence group.

        The span starts at the request's arrival time (in nanoseconds) and
        uses the trace context extracted from ``seq_group.trace_headers``.
        Model, request id, sampling parameters and token usage are always
        attached; latency attributes are attached only when the underlying
        metric is not None. No-op when tracing is disabled or the group has
        no sampling params.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            sampling_params = seq_group.sampling_params

            # Handle potential None values for cancelled/aborted requests
            ttft = (metrics.first_token_time - metrics.arrival_time
                    if metrics.first_token_time is not None else None)

            e2e_time = (metrics.finished_time - metrics.arrival_time
                        if metrics.finished_time is not None else None)

            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   sampling_params.n)
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Only set timing attributes if the values are available
            if metrics.time_in_queue is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                    metrics.time_in_queue)
            if ttft is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            if e2e_time is not None:
                seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E,
                                       e2e_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # NOTE(review): only this metric is divided by 1000;
                # presumably it is recorded in milliseconds while the others
                # are in seconds — confirm against where it is set.
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)
1920

1921
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Validate the encoder part (if any) and the decoder part of inputs.

        Raises:
            ValueError: propagated from ``_validate_model_input`` when either
                part fails validation.
        """
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")
1933

1934
1935
1936
1937
1938
1939
1940
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate a single (encoder or decoder) prompt.

        Raises:
            ValueError: if the prompt has no token ids (allowed only for
                multimodal encoder prompts and ``"embeds"`` inputs); if a
                token id exceeds the tokenizer's ``max_token_id``; or if the
                prompt is longer than ``max_model_len``. The length check is
                skipped for a multimodal encoder prompt whose processor sets
                ``pad_dummy_encoder_prompt``.
        """
        model_config = self.model_config
        tokenizer = (None if self.tokenizer is None else
                     self.tokenizer.get_lora_tokenizer(lora_request))

        prompt_ids = prompt_inputs.get("prompt_token_ids", [])
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            elif prompt_inputs["type"] == "embeds":
                pass
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        if tokenizer is not None:
            max_input_id = max(prompt_ids, default=0)
            if max_input_id > tokenizer.max_token_id:
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer or object(),  # Dummy if no tokenizer
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
1992
1993
1994
1995

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors from the logit_bias,
        allowed_token_ids and bad_words fields in sampling_params.

        ``logit_bias`` and ``allowed_token_ids`` are converted into logits
        processors and then cleared so they are not passed down to the model.
        ``bad_words`` is converted into logits processors but left in place.
        All constructed processors are appended to
        ``sampling_params.logits_processors``, which is created if it is
        ``None``.

        NOTE(review): ``sampling_params`` is modified in place and then
        returned. Because ``bad_words`` is not cleared, calling this twice on
        the same object appends the bad-words processors a second time.
        Callers should presumably pass a per-request object; confirm at the
        call sites.

        Args:
            sampling_params: Sampling parameters to update.
            lora_request: Optional LoRA request, used to select the tokenizer.

        Returns:
            The same ``sampling_params`` object, modified.
        """
        has_openai_params = bool(sampling_params.logit_bias
                                 or sampling_params.allowed_token_ids)
        # Truthiness rather than len(): a None `bad_words` is treated like an
        # empty list instead of raising TypeError.
        bad_words = sampling_params.bad_words
        if not (has_openai_params or bad_words):
            return sampling_params

        # Both processor families need the (possibly LoRA-specific)
        # tokenizer; fetch it once and share it.
        tokenizer = self.get_tokenizer(lora_request=lora_request)
        logits_processors = []

        if has_openai_params:
            logits_processors.extend(
                get_openai_logits_processors(
                    logit_bias=sampling_params.logit_bias,
                    allowed_token_ids=sampling_params.allowed_token_ids,
                    tokenizer=tokenizer))

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if bad_words:
            logits_processors.extend(
                get_bad_words_logits_processors(bad_words=bad_words,
                                                tokenizer=tokenizer))

        if logits_processors:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = logits_processors
            else:
                sampling_params.logits_processors.extend(logits_processors)

        return sampling_params
2029

2030
2031
2032
2033
2034
2035
2036
2037
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Run ``method`` through the model executor's ``collective_rpc``.

        This is a thin pass-through. All arguments are forwarded positionally
        and unchanged, and the executor's result is returned as-is.

        Args:
            method: Name of the method to invoke, or a callable.
            timeout: Optional timeout, forwarded to the executor unchanged.
            args: Positional arguments forwarded to the executor (presumably
                applied to ``method``).
            kwargs: Keyword arguments forwarded to the executor (presumably
                applied to ``method``).

        Returns:
            The list produced by ``self.model_executor.collective_rpc``.
        """
        executor = self.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)

2038

2039
2040
2041
# When the VLLM_USE_V1 environment flag is explicitly set and truthy, rebind
# the module-level name `LLMEngine` to the V1 engine class. Code that imports
# `LLMEngine` from this module after it has loaded then receives the V1
# implementation rather than the one presumably defined earlier in this file.
# NOTE(review): `envs.is_set` is checked first, presumably so that an unset
# variable does not select V1 through a default value; confirm in vllm.envs.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    # The type-ignore presumably silences the checker's complaint about
    # reassigning the class name `LLMEngine`.
    LLMEngine = V1LLMEngine  # type: ignore