llm_engine.py 89.9 KB
Newer Older
1
import copy
Antoni Baum's avatar
Antoni Baum committed
2
import time
3
from collections import Counter as collectionsCounter
4
from collections import deque
5
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from functools import partial
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                    Iterable, List, Mapping, NamedTuple, Optional)
10
from typing import Sequence as GenericSequence
11
from typing import Set, Type, Union, cast, overload
12

13
import torch
14
from typing_extensions import TypeVar, deprecated
15

16
import vllm.envs as envs
17
18
19
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
20
21
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.engine.arg_utils import EngineArgs
23
from vllm.engine.metrics_types import StatLoggerBase, Stats
24
25
26
27
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
28
29
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
30
from vllm.executor.executor_base import ExecutorBase
31
from vllm.executor.gpu_executor import GPUExecutor
32
from vllm.executor.ray_utils import initialize_ray_cluster
33
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
34
                         PromptType, SingletonInputsAdapter)
35
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
36
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.logger import init_logger
38
from vllm.logits_process import get_bad_words_logits_processors
39
from vllm.lora.request import LoRARequest
40
41
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
42
from vllm.model_executor.layers.sampler import SamplerOutput
43
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
44
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
45
46
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
47
from vllm.prompt_adapter.request import PromptAdapterRequest
48
from vllm.sampling_params import RequestOutputKind, SamplingParams
49
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
50
51
52
53
                           ParallelSampleSequenceGroup, Sequence,
                           SequenceGroup, SequenceGroupBase,
                           SequenceGroupMetadata, SequenceGroupOutput,
                           SequenceStatus)
54
55
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
56
from vllm.transformers_utils.config import try_get_generation_config
57
from vllm.transformers_utils.detokenizer import Detokenizer
58
from vllm.transformers_utils.tokenizer import AnyTokenizer
59
from vllm.transformers_utils.tokenizer_group import (
60
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
61
62
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
63
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
64
from vllm.version import __version__ as VLLM_VERSION
65
66

logger = init_logger(__name__)
67
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
68

69

70
71
72
73
74
75
76
77
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Load the generation config of a model as a plain dictionary.

    The config is looked up with ``try_get_generation_config`` using the
    model name, revision and ``trust_remote_code`` setting taken from
    ``model_config``.

    Args:
        model_config: Model configuration identifying which generation
            config to look up.

    Returns:
        The result of ``to_diff_dict()`` on the loaded generation config,
        or an empty dict when no generation config could be obtained.
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if generation_config is not None:
        return generation_config.to_diff_dict()

    # No generation config available for this model.
    return {}

82

83
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
84
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
85
86


87
88
89
90
91
@dataclass
class SchedulerOutputState:
    """Cached scheduler outputs for one virtual engine (used for Multi-Step).

    Every field has a default, so a freshly created state holds no cached
    data.
    """
    # Sequence-group metadata kept from a scheduler invocation, if any.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scheduler outputs kept from a scheduler invocation, if any.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # NOTE(review): presumably records whether async output processing was
    # allowed for the cached schedule -- confirm against the step loop.
    allow_async_output_proc: bool = False
    # Most recent sampler output stored alongside the cached schedule.
    last_output: Optional[SamplerOutput] = None
94
95


96
97
98
99
100
101
class OutputData(NamedTuple):
    """One entry of a ``SchedulerContext`` output queue.

    Bundles sampler outputs with the scheduling information they were
    appended with.
    """
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # Whether this output comes from the first step of the multi-step.
    # Always True when multi-step is disabled. The flag is not meaningful
    # when `outputs` contains outputs from multiple steps.
    is_first_step_output: Optional[bool]
    # NOTE(review): presumably indices to skip while processing this
    # entry; entries are created with an empty list -- confirm with the
    # consumer of the queue.
    skip: List[int]


111
class SchedulerContext:
    """Mutable per-virtual-engine state shared with output processing.

    Holds a queue of pending ``OutputData`` entries, the list of request
    outputs, and the most recently stored scheduler results.
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        """Create an empty context.

        Args:
            multi_step_stream_outputs: Stored as-is on the context.
        """
        # Entries pushed by `append_output`.
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []

        # Nothing has been scheduled yet.
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Wrap the arguments in an ``OutputData`` and enqueue it.

        The new entry always starts with an empty ``skip`` list.
        """
        entry = OutputData(
            outputs=outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            scheduler_outputs=scheduler_outputs,
            is_async=is_async,
            is_last_step=is_last_step,
            is_first_step_output=is_first_step_output,
            skip=[],
        )
        self.output_queue.append(entry)
136
137


138
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
139
    """An LLM engine that receives requests and generates texts.
140

Woosuk Kwon's avatar
Woosuk Kwon committed
141
    This is the main class for the vLLM engine. It receives requests
142
143
144
145
146
147
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

148
149
    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
150

151
152
    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine_args`)
153
154
155
156
157
158
159

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
160
        device_config: The configuration related to the device.
161
162
163
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
164
165
        executor_class: The model executor class for managing distributed
            execution.
166
        prompt_adapter_config (Optional): The configuration related to serving
167
            prompt adapters.
168
        log_stats: Whether to log statistics.
169
        usage_context: Specified entry point, used for usage info collection.
170
    """
171

172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Context manager that turns on request-output type validation.

        Sets ``DO_VALIDATE_OUTPUT`` to True for the duration of the
        ``with`` block and resets it to False on exit. The reset happens
        in a ``finally`` clause so the flag is cleared even when the body
        of the ``with`` block raises.
        """
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            # Without the finally, an exception in the with-body would
            # leave validation permanently enabled on the class.
            cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` typed as ``output_type``, optionally checked.

        The ``isinstance`` check only runs when ``DO_VALIDATE_OUTPUT`` is
        set (``TYPE_CHECKING`` is included in the condition for the
        benefit of static type checkers).

        Raises:
            TypeError: If validation is enabled and ``output`` is not an
                instance of ``output_type``.
        """
        should_check = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT

        if should_check:
            if not isinstance(output, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(output)}")

        # No runtime conversion happens here; this only informs the
        # type checker.
        return cast(_O, output)
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return ``outputs`` typed as a list of ``output_type``.

        When ``DO_VALIDATE_OUTPUT`` is set, every element is checked with
        ``isinstance`` and a new list is returned. Otherwise the input
        sequence object itself is returned unchanged (no copy is made).

        Raises:
            TypeError: If validation is enabled and an element is not an
                instance of ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # Fast path: trust the caller and hand back the same object.
            return outputs

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)

        return checked

    tokenizer: Optional[BaseTokenizerGroup]

223
224
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the engine: tokenizer, executor, schedulers and loggers.

        Args:
            vllm_config: Aggregate configuration; the individual
                sub-configs (model, cache, parallel, scheduler, ...) are
                read from it and stored on the engine.
            executor_class: Model executor class; instantiated with
                ``vllm_config``.
            log_stats: Whether to set up stat loggers.
            usage_context: Passed to the usage reporter when usage stats
                are enabled.
            stat_loggers: Stat loggers to use instead of the default
                logging/prometheus pair. Only consulted when
                ``log_stats`` is True.
            input_registry: Registry used to create the input processor.
            mm_registry: Multi-modal registry handed to the input
                preprocessor.
            use_cached_outputs: Stored on the engine and included in the
                startup log line.
        """

        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        # NOTE: every "key=value" pair is separated by ", " so the log
        # line can be read (and parsed) field by field.
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "override_neuron_config=%s, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "pipeline_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, "
            "num_scheduler_steps=%d, chunked_prefill_enabled=%s, "
            "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
            "use_async_output_proc=%s, use_cached_outputs=%s, "
            "mm_processor_kwargs=%s, pooler_config=%r, "
            "compilation_config=%r",
            VLLM_VERSION,
            self.model_config.model,
            self.speculative_config,
            self.model_config.tokenizer,
            self.model_config.skip_tokenizer_init,
            self.model_config.tokenizer_mode,
            self.model_config.revision,
            self.model_config.override_neuron_config,
            self.model_config.tokenizer_revision,
            self.model_config.trust_remote_code,
            self.model_config.dtype,
            self.model_config.max_model_len,
            self.load_config.download_dir,
            self.load_config.load_format,
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size,
            self.parallel_config.disable_custom_all_reduce,
            self.model_config.quantization,
            self.model_config.enforce_eager,
            self.cache_config.cache_dtype,
            self.model_config.quantization_param_path,
            self.device_config.device,
            self.decoding_config,
            self.observability_config,
            self.model_config.seed,
            self.model_config.served_model_name,
            self.scheduler_config.num_scheduler_steps,
            self.scheduler_config.chunked_prefill_enabled,
            self.scheduler_config.multi_step_stream_outputs,
            self.cache_config.enable_prefix_caching,
            self.model_config.use_async_output_proc,
            use_cached_outputs,
            self.model_config.mm_processor_kwargs,
            self.model_config.pooler_config,
            vllm_config.compilation_config,
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            self.model_config)

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(vllm_config=vllm_config, )

        # KV caches are not initialized for the "embedding" task.
        if self.model_config.task != "embedding":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cached-output slot and one scheduler context per pipeline
        # stage (virtual engine).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

470
471
472
473
474
475
    def _initialize_kv_caches(self) -> None:
        """Size and allocate the KV cache in the worker(s).

        The executor reports how many GPU and CPU (swap) blocks are
        available. If ``cache_config.num_gpu_blocks_override`` is set, it
        replaces the reported GPU block count. The final counts are
        written back to ``cache_config`` and passed to the executor's
        ``initialize_cache``. The elapsed wall-clock time is logged.
        """
        started_at = time.time()

        gpu_blocks, cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        override = self.cache_config.num_gpu_blocks_override
        if override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", gpu_blocks, override)
            gpu_blocks = override

        # Publish the final block counts on the shared cache config
        # before the executor allocates the cache.
        self.cache_config.num_gpu_blocks = gpu_blocks
        self.cache_config.num_cpu_blocks = cpu_blocks
        self.model_executor.initialize_cache(gpu_blocks, cpu_blocks)

        duration = time.time() - started_at
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), duration)
495

496
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Pick the executor class for the configured device and backend.

        Selection order: an explicitly supplied executor class, then the
        device type (neuron, tpu, cpu, hpu, openvino, xpu), then the
        distributed backend ("ray", "mp"), and finally the plain
        ``GPUExecutor``. A Ray cluster is initialized whenever a Ray
        based executor is chosen.

        Args:
            engine_config: Engine configuration; only ``parallel_config``
                and ``device_config`` are consulted.

        Returns:
            The executor class to instantiate.

        Raises:
            TypeError: If ``distributed_executor_backend`` is a class
                that does not derive from ``ExecutorBase``.
            NotImplementedError: If the "mp" backend is requested on XPU,
                which is not supported.
            AssertionError: If a TPU device is combined with a backend
                other than "ray"/None, or the "mp" backend is combined
                with ``VLLM_USE_RAY_SPMD_WORKER``.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutor
                executor_class = RayTPUExecutor
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutor
                executor_class = TPUExecutor
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.device_config.device_type == "hpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_hpu_executor import RayHPUExecutor
                executor_class = RayHPUExecutor
            else:
                from vllm.executor.hpu_executor import HPUExecutor
                executor_class = HPUExecutor
        elif engine_config.device_config.device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            executor_class = OpenVINOExecutor
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutor
                executor_class = RayXPUExecutor
            elif distributed_executor_backend == "mp":
                # FIXME(kunshang):
                # spawn needs calling `if __name__ == '__main__':``
                # fork is not supported for xpu start new process.
                logger.error(
                    "Both start methods (spawn and fork) have issue "
                    "on XPU if you use mp backend, Please try ray instead.")
                # No executor exists for this combination. Fail with a
                # clear error rather than falling through to the return
                # below with `executor_class` unbound.
                raise NotImplementedError(
                    "The 'mp' distributed executor backend is not "
                    "supported on XPU. Please try ray instead.")
            else:
                from vllm.executor.xpu_executor import XPUExecutor
                executor_class = XPUExecutor
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingGPUExecutor
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The engine config is derived from ``engine_args``, the executor
        class is selected from that config, and stat logging is enabled
        unless ``engine_args.disable_log_stats`` is set.
        """
        vllm_config = engine_args.create_engine_config(usage_context)

        return cls(
            vllm_config=vllm_config,
            executor_class=cls._get_executor_cls(vllm_config),
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
588

589
590
591
592
593
    def __reduce__(self):
        """Refuse to be pickled.

        Raising here guarantees the engine is never captured by the
        closure used to initialize Ray worker actors.

        Raises:
            RuntimeError: Always.
        """
        reason = "LLMEngine should not be pickled!"
        raise RuntimeError(reason)

594
595
596
597
598
599
    def __del__(self):
        """Shut down the model executor when the engine is collected."""
        # `__init__` may have failed before `model_executor` was assigned,
        # so the attribute is looked up defensively.
        executor = getattr(self, "model_executor", None)
        if executor:
            executor.shutdown()

600
    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the engine's tokenizer group, checked against a type.

        Args:
            group_type: Expected class of the tokenizer group.

        Raises:
            ValueError: If the engine has no tokenizer
                (``skip_tokenizer_init`` was True).
            TypeError: If the tokenizer group is not an instance of
                ``group_type``.
        """
        group = self.tokenizer

        if group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        if isinstance(group, group_type):
            return group

        raise TypeError("Invalid type of tokenizer group. "
                        f"Expected type: {group_type}, but "
                        f"found type: {type(group)}")
615

616
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for the given LoRA request.

        Delegates to the tokenizer group's ``get_lora_tokenizer``; errors
        from ``get_tokenizer_group`` propagate unchanged.
        """
        group = self.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
621

622
623
624
625
626
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Create the tokenizer group from the engine's configs.

        Uses the model, scheduler, parallel and LoRA configs stored on
        the engine.
        """
        tokenizer_group = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            lora_config=self.lora_config,
        )
        return tokenizer_group
628

629
630
    def _verify_args(self) -> None:
        """Cross-check the engine's configs against each other.

        The model and cache configs are verified against the parallel
        config. The LoRA and prompt-adapter configs are only verified
        when they are set.
        """
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        lora_config = self.lora_config
        if lora_config:
            lora_config.verify_with_model_config(self.model_config)
            lora_config.verify_with_scheduler_config(self.scheduler_config)

        adapter_config = self.prompt_adapter_config
        if adapter_config:
            adapter_config.verify_with_model_config(self.model_config)
639

640
641
642
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Turn processed inputs into a sequence group and enqueue it.

        Sampling requests with ``n > 1`` are delegated to
        :class:`ParallelSampleSequenceGroup` and ``None`` is returned for
        them. Any other request produces one :class:`SequenceGroup`, which is
        added to the scheduler reporting the fewest unfinished sequence
        groups and then returned.

        Raises:
            ValueError: If ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``.
        """
        # Fan-out requests are registered by the parallel-sampling helper.
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)

        # Gather what every sequence of this request needs.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        # Encoder/decoder requests carry two sets of inputs; plain requests
        # only have decoder inputs.
        decoder_inputs = processed_inputs
        encoder_inputs = None
        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = processed_inputs["decoder"]
            encoder_inputs = processed_inputs["encoder"]

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        # The encoder sequence (if any) shares the id of the decoder sequence.
        encoder_seq = None
        if encoder_inputs is not None:
            encoder_seq = Sequence(seq_id, encoder_inputs, block_size,
                                   eos_token_id, lora_request,
                                   prompt_adapter_request)

        # The kind of params decides which flavour of group is created.
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Pick the scheduler with the fewest unfinished sequence groups; on a
        # tie the first such scheduler wins.
        least_loaded = min(
            self.scheduler,
            key=lambda sched: sched.get_num_unfinished_seq_groups())
        least_loaded.add_seq_group(seq_group)

        return seq_group

724
725
    def stop_remote_worker_execution_loop(self) -> None:
        """Ask the model executor to stop its remote worker execution loop."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
726

727
728
    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Deprecated signature: the prompt is passed as keyword ``inputs``.

        Typing-only overload; new code should pass ``prompt`` instead.
        """
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Preferred signature: the prompt is passed as ``prompt``.

        Typing-only overload.
        """
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Add a request to the engine's request pool.

        The request is checked, preprocessed and queued. It is then processed
        by the scheduler as `engine.step()` is called. The exact scheduling
        policy is determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: Optional LoRA request. Only accepted when the
                engine has a LoRA config.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Optional prompt adapter request. It is
                forwarded to preprocessing and to the queued request.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            inputs: Deprecated alias of ``prompt``. When given, it replaces
                ``prompt``.

        Raises:
            ValueError: If a LoRA request is given but LoRA is not enabled,
                if a non-zero priority is given without priority scheduling,
                or if guided decoding / logits processors are requested
                together with multi-step decoding.

        Details:
            - Use ``inputs`` as the prompt if the deprecated keyword is given.
            - Reject argument combinations the engine does not support.
            - Set arrival_time to the current time if it is None.
            - Check token prompts against the tokenizer (when the engine has
              a tokenizer).
            - Preprocess the prompt and run the input processor.
            - Hand the result to ``_add_processed_request``, which creates
              the :class:`~vllm.SequenceGroup` and adds it to a scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        # The deprecated keyword wins over the positional prompt.
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        lora_enabled = bool(self.lora_config)
        if lora_request is not None and not lora_enabled:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if (isinstance(params, SamplingParams)
                and (params.guided_decoding or params.logits_processors)
                and self.scheduler_config.num_scheduler_steps > 1):
            raise ValueError(
                "Guided decoding and logits processors are not supported "
                "in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        # Token prompts can only be range-checked when a tokenizer exists.
        if self.tokenizer is not None:
            lora_tokenizer = self.get_tokenizer(lora_request=lora_request)
            self._validate_token_prompt(prompt, tokenizer=lora_tokenizer)

        preprocessed_inputs = self.input_preprocessor.preprocess(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )
862

863
864
865
866
867
868
869
870
871
872
873
    def _validate_token_prompt(self, prompt: PromptType,
                               tokenizer: AnyTokenizer):
        """Reject token prompts containing ids outside the vocabulary.

        Prompts that are not token prompts, and empty token prompts, are left
        to later checks.

        Args:
            prompt: The raw prompt passed to ``add_request``.
            tokenizer: Tokenizer whose ``max_token_id`` bounds valid ids.

        Raises:
            ValueError: If any token id is negative or greater than
                ``tokenizer.max_token_id``.
        """
        # Guard against out-of-vocab tokens.
        # For some tokenizers, tokenizer.decode will happily return empty text
        # for token ids that are out of vocab, and we don't detect token ids
        # that are greater than the max token id before running the model.
        # However, these token ids will later crash a cuda kernel at runtime
        # with an index out of bounds error. This will crash the entire engine.
        # This needs to happen before multimodal input pre-processing, which
        # may add dummy <image> tokens that aren't part of the tokenizer's
        # vocabulary.
        if not is_token_prompt(prompt):
            return

        prompt_ids = prompt["prompt_token_ids"]
        if len(prompt_ids) == 0:
            # Empty prompt check is handled later
            return

        max_input_id = max(prompt_ids)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(
                f"Token id {max_input_id} is out of vocabulary")

        # Negative ids are just as invalid as ids above the maximum.
        min_input_id = min(prompt_ids)
        if min_input_id < 0:
            raise ValueError(
                f"Token id {min_input_id} is out of vocabulary")

884
885
886
887
888
    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a sampling request.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model config's ``max_logprobs``.
        """
        max_logprobs = self.get_model_config().max_logprobs
        requested_counts = (sampling_params.logprobs,
                            sampling_params.prompt_logprobs)
        if any(count and count > max_logprobs for count in requested_counts):
            raise ValueError(
                f"Cannot request more than {max_logprobs} logprobs.")

        sampling_params = self._build_logits_processors(
            sampling_params, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()

        sampling_params.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority,
        )

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a pooling request."""
        # Work on a clone so the caller's PoolingParams object is not shared
        # with the pooler.
        params_copy = pooling_params.clone()

        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=params_copy,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority,
        )
954

Antoni Baum's avatar
Antoni Baum committed
955
956
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort the request(s) with the given ID(s).

        Args:
            request_id: A single request ID or an iterable of request IDs.

        Details:
            - The ID(s) are passed unchanged to
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              of every scheduler owned by the engine.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for sched in self.scheduler:
            sched.abort_seq_group(request_id)
974

975
976
977
978
    def get_model_config(self) -> ModelConfig:
        """Return the engine's model configuration object."""
        model_cfg = self.model_config
        return model_cfg

979
980
981
982
    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's parallel configuration object."""
        parallel_cfg = self.parallel_config
        return parallel_cfg

983
984
985
986
    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's decoding configuration object."""
        decoding_cfg = self.decoding_config
        return decoding_cfg

987
988
989
990
991
992
993
994
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's scheduler configuration object."""
        scheduler_cfg = self.scheduler_config
        return scheduler_cfg

    def get_lora_config(self) -> LoRAConfig:
        """Return the engine's LoRA configuration object.

        NOTE(review): other methods test ``self.lora_config`` for truthiness
        before using it, so this may be unset -- confirm before relying on a
        non-None result.
        """
        lora_cfg = self.lora_config
        return lora_cfg

995
    def get_num_unfinished_requests(self) -> int:
        """Count unfinished sequence groups across all schedulers."""
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total
999

1000
    def has_unfinished_requests(self) -> bool:
        """Return True if any scheduler still has unfinished sequences."""
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Return True if the given virtual engine has unfinished requests.

        ``virtual_engine`` is used as an index into the engine's scheduler
        list.
        """
        engine_scheduler = self.scheduler[virtual_engine]
        return engine_scheduler.has_unfinished_seqs()
1011

1012
    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        """Store the embedding result and finish every sequence of the group.

        Only the first entry of ``outputs`` is read. Its embeddings are
        attached to ``seq_group``, after which each sequence of the group is
        marked ``FINISHED_STOPPED``.
        """
        first_output = outputs[0]
        seq_group.embeddings = first_output.embeddings

        for sequence in seq_group.get_seqs():
            sequence.status = SequenceStatus.FINISHED_STOPPED

1024
1025
1026
1027
1028
1029
1030
1031
    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """
        This function updates num_computed_tokens for prompt sequences
        when Multi-Step is enabled.

1032
        seq_group: SequenceGroup to update the num_computed_tokens for.
1033
        seq_group_meta: Metadata of the given SequenceGroup.
1034
        is_first_step_output: Optional[bool] -
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
            When available, is_first_step_output indicates if the appended
            output token is the output of the first-step in multi-step.
            A value of None indicates that outputs from all steps in
            in multi-step are submitted in a single burst.
        """

        assert self.scheduler_config.is_multi_step

        if not seq_group_meta.is_prompt:
            # num_computed_token updates for multi-step decodes happen after
            # the tokens are appended to the sequence.
            return

        do_update: bool = False
        if self.scheduler_config.chunked_prefill_enabled:
            # In multi-step + chunked-prefill case, the prompt sequences
            # that are scheduled are fully processed in the first step.
            do_update = is_first_step_output is None or is_first_step_output
        else:
            # Normal multi-step decoding case. In this case prompt-sequences
            # are actually single-stepped. Always update in this case.
            assert seq_group.state.num_steps == 1
            do_update = True

        if do_update:
            seq_group.update_num_computed_tokens(
                seq_group_meta.token_chunk_size)

1063
1064
1065
1066
1067
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and collect the resulting request outputs.

        Nothing is returned: request outputs are appended to
        ``ctx.request_outputs`` and, when ``process_request_outputs_callback``
        is set, handed to that callback and cleared.

        ctx: The virtual engine context to work on
        request_id: If provided, then only this request is going to be processed
        """

        # Single timestamp reused for every first-token-time update below.
        now = time.time()

        # Nothing was queued for post-processing.
        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
            outputs_by_sequence_group = create_output_by_sequence_group(
                outputs, num_seq_groups=len(seq_group_metadata_list))
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
        else:
            outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        # Indices of groups already finished on entry, and of groups that
        # become finished while being processed in the loop below.
        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            # Entries in `skip` were handled by an earlier single-request call.
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            # Per-group output: all steps for this group, or the single
            # step's entry wrapped in a list.
            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
                        seq_group_meta.token_chunk_size or 0)

            # Accumulate model forward/execute times into the group's
            # metrics; a missing per-output time counts as 0 when adding.
            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            # Embedding task: store embeddings directly. Otherwise go
            # through the output processor (sampling only when requested).
            if self.model_config.task == "embedding":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Create the outputs
        for i in indices:
            if i in skip or i in finished_before or i in finished_now:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Also emit outputs for the groups the scheduler ignored, except
        # unfinished ones whose output kind is DELTA.
        for seq_group in scheduler_outputs.ignored_seq_groups:
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done inside output processor, but it is
        required if the worker is to perform async forward pass to next step.
        """
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output, scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue

1294
1295
1296
1297
1298
            if self.scheduler_config.is_multi_step:
                # Updates happen only if the sequence is prefill
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, seq_group_metadata,
                    seq_group.state.num_steps == 1)
1299
            else:
1300
1301
1302
1303
                token_chunk_size = (seq_group_metadata.token_chunk_size
                                    if seq_group_metadata.token_chunk_size
                                    is not None else 0)
                seq_group.update_num_computed_tokens(token_chunk_size)
1304

1305
1306
1307
            if seq_group_metadata.do_sample:
                assert len(sequence_group_outputs.samples) == 1, (
                    "Async output processor expects a single sample"
1308
                    " (i.e sampling_params.n == 1)")
1309
1310
1311
1312
                sample = sequence_group_outputs.samples[0]

                assert len(seq_group.seqs) == 1
                seq = seq_group.seqs[0]
1313
1314
1315
1316
1317
1318
1319
1320
1321

                if self.scheduler_config.is_multi_step:
                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
                    ) == 0
                    seq.append_token_id(sample.output_token, sample.logprobs)
                    if not is_prefill_append:
                        seq_group.update_num_computed_tokens(1)
                else:
                    seq.append_token_id(sample.output_token, sample.logprobs)
1322

1323
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Returns:
            The ``request_outputs`` list held by the scheduler context of
            virtual engine 0. The list is cleared at the start of every call.

        Raises:
            NotImplementedError: If ``pipeline_parallel_size`` is greater
                than 1.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # For llm_engine, there is no pipeline parallel support, so the engine
        # used is always 0.
        virtual_engine = 0

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            # Re-using a cached schedule: the scheduler was not consulted, so
            # there are no newly finished request ids to forward.
            finished_requests_ids = []

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

            # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps.
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            # Add results to the output_queue
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")

                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Check if need to run the usual non-async path
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1527

1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Report whether the given batch still has multi-step work pending.

        Args:
            seq_group_metadata_list: The batch to inspect; may be ``None`` or
                empty.

        Returns:
            ``False`` when multi-step scheduling is disabled or the batch is
            ``None``/empty; otherwise whether the first entry's
            ``state.remaining_steps`` is greater than zero.

        Raises:
            AssertionError: If the entries do not all report the same
                ``state.remaining_steps``.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        # Generator (not a list) so the check stops at the first mismatch.
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store a scheduling result in the per-virtual-engine cache.

        Overwrites the cached sequence-group metadata, scheduler outputs and
        async-output flag for ``virtual_engine`` and resets the cached
        ``last_output`` to ``None``.

        Args:
            virtual_engine: Index into ``self.cached_scheduler_outputs``.
            seq_group_metadata_list: Metadata list to cache.
            scheduler_outputs: Scheduler outputs to cache.
            allow_async_output_proc: Async-output flag to cache.
        """
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Remember the final sampler output of this step for ``virtual_engine``.

        Does nothing unless pipeline parallelism is in use
        (``pipeline_parallel_size > 1``) and ``output`` is non-empty with a
        non-``None`` first element. Otherwise the last element of ``output`` is
        stored as the cached ``last_output``.
        """
        # Only relevant when more than one pipeline stage exists.
        if not self.parallel_config.pipeline_parallel_size > 1:
            return
        # Nothing usable was produced this step.
        if not len(output) > 0 or output[0] is None:
            return

        final_output = output[-1]
        assert final_output is not None
        assert final_output.sampled_token_ids_cpu is not None
        assert final_output.sampled_token_ids is None
        assert final_output.sampled_token_probs is None

        cache_entry = self.cached_scheduler_outputs[virtual_engine]
        cache_entry.last_output = final_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1584
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` in ``self.stat_loggers`` under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If a logger with the same name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        self.stat_loggers[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Remove the logger registered under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If no logger with that name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

1602
1603
1604
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Collect the current stats and hand them to every stat logger.

        Does nothing when ``self.log_stats`` is falsy. All arguments are
        forwarded unchanged to ``_get_stats``; calling with no arguments
        forces a log even when no requests are active.
        """
        if self.log_stats:
            stats = self._get_stats(scheduler_outputs, model_output,
                                    finished_before, skip)
            for logger in self.stat_loggers.values():
                logger.log(stats)
1613

1614
1615
1616
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional, indices of sequences that were finished
                before. These sequences will be ignored.
            skip: Optional, indices of sequences that were preempted. These
                sequences will be ignored.

        Returns:
            A ``Stats`` snapshot stamped with the current ``time.time()``.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        time_in_queue_requests: List[float] = []
        model_forward_time_requests: List[float] = []
        model_execute_time_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # Lora requests
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            # Build the index sets once so the per-group membership tests
            # below are O(1) instead of scanning the lists every iteration.
            finished_before_idx = frozenset(finished_before or ())
            skip_idx = frozenset(skip or ())

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_before_idx:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skip_idx:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
                    if seq_group.metrics.time_in_queue is not None:
                        time_in_queue_requests.append(
                            seq_group.metrics.time_in_queue)
                    if seq_group.metrics.model_forward_time is not None:
                        model_forward_time_requests.append(
                            seq_group.metrics.model_forward_time)
                    if seq_group.metrics.model_execute_time is not None:
                        model_execute_time_requests.append(
                            seq_group.metrics.model_execute_time * 1000)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)
        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            time_in_queue_requests=time_in_queue_requests,
            model_forward_time_requests=model_forward_time_requests,
            model_execute_time_requests=model_execute_time_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            max_lora=str(max_lora_stat),
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
1883

1884
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model executor's ``add_lora``.

        Returns:
            Whatever the executor's ``add_lora`` returns.
        """
        return self.model_executor.add_lora(lora_request)
1886
1887

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``remove_lora``.

        Returns:
            Whatever the executor's ``remove_lora`` returns.
        """
        return self.model_executor.remove_lora(lora_id)
1889

1890
    def list_loras(self) -> Set[int]:
        """Return the result of the model executor's ``list_loras``."""
        return self.model_executor.list_loras()
1892

1893
1894
1895
    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``pin_lora``."""
        executor = self.model_executor
        pinned = executor.pin_lora(lora_id)
        return pinned

1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward the request to the executor's ``add_prompt_adapter``."""
        executor = self.model_executor
        added = executor.add_prompt_adapter(prompt_adapter_request)
        return added

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward the id to the executor's ``remove_prompt_adapter``."""
        executor = self.model_executor
        removed = executor.remove_prompt_adapter(prompt_adapter_id)
        return removed

    def list_prompt_adapters(self) -> List[int]:
        """Return the result of the executor's ``list_prompt_adapters``."""
        executor = self.model_executor
        adapter_ids = executor.list_prompt_adapters()
        return adapter_ids

1906
    def check_health(self) -> None:
        """Run the health checks of the tokenizer (if set) and the executor.

        The tokenizer check is skipped when ``self.tokenizer`` is falsy; the
        model executor check always runs afterwards.
        """
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
1910

1911
    def start_profile(self) -> None:
        """Start profiling on the model executor.

        An executor whose type is exactly ``GPUExecutor`` is asked directly;
        any other executor is driven through ``_run_workers``.
        """
        # using type instead of isinstance to check to avoid capturing
        # inherited classes (MultiprocessingGPUExecutor)
        if type(self.model_executor) == GPUExecutor:  # noqa: E721
            self.model_executor.start_profile()
        else:
            self.model_executor._run_workers("start_profile")
1918
1919

    def stop_profile(self) -> None:
        """Stop profiling on the model executor.

        An executor whose type is exactly ``GPUExecutor`` is asked directly;
        any other executor is driven through ``_run_workers``.
        """
        # using type instead of isinstance to check to avoid capturing
        # inherited classes (MultiprocessingGPUExecutor)
        if type(self.model_executor) == GPUExecutor:  # noqa: E721
            self.model_executor.stop_profile()
        else:
            self.model_executor._run_workers("stop_profile")
1926

1927
1928
1929
    def is_tracing_enabled(self) -> bool:
        """Return ``True`` when a tracer is configured on this engine."""
        tracer = self.tracer
        if tracer is None:
            return False
        return True

1930
1931
1932
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for every finished scheduled sequence group.

        Does nothing when no tracer is configured.

        Args:
            scheduler_outputs: Its ``scheduled_seq_groups`` are inspected.
            finished_before: Optional indices into ``scheduled_seq_groups``
                that are skipped to avoid tracing them twice.
        """
        if self.tracer is None:
            return

        # Build the index set once so each membership test is O(1).
        finished_before_idx = frozenset(finished_before or ())

        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if idx in finished_before_idx:
                continue

            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record an ``llm_request`` span describing ``seq_group``.

        Returns without creating a span when no tracer is set or when the
        group has no sampling params. Otherwise a span is started with the
        group's arrival time (as integer nanoseconds) as its start time and
        the context extracted from the group's trace headers, and request,
        usage and latency attributes are set on it.

        Args:
            seq_group: The sequence group whose request is being traced.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return

        start_ns = int(seq_group.metrics.arrival_time * 1e9)
        parent_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=parent_context,
                start_time=start_ns) as span:
            timing = seq_group.metrics
            time_to_first = timing.first_token_time - timing.arrival_time
            end_to_end = timing.finished_time - timing.arrival_time
            set_attr = span.set_attribute

            # attribute names are based on
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md

            # Request description.
            set_attr(SpanAttributes.LLM_RESPONSE_MODEL,
                     self.model_config.model)
            set_attr(SpanAttributes.LLM_REQUEST_ID, seq_group.request_id)
            set_attr(SpanAttributes.LLM_REQUEST_TEMPERATURE,
                     seq_group.sampling_params.temperature)
            set_attr(SpanAttributes.LLM_REQUEST_TOP_P,
                     seq_group.sampling_params.top_p)
            set_attr(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                     seq_group.sampling_params.max_tokens)
            set_attr(SpanAttributes.LLM_REQUEST_N,
                     seq_group.sampling_params.n)

            # Usage counters; completion tokens only count finished seqs.
            set_attr(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
                     seq_group.num_seqs())
            set_attr(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                     len(seq_group.prompt_token_ids))
            set_attr(
                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Latencies that are always reported.
            set_attr(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                     timing.time_in_queue)
            set_attr(SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN,
                     time_to_first)
            set_attr(SpanAttributes.LLM_LATENCY_E2E, end_to_end)

            # Latencies that are only reported when they were collected.
            if timing.scheduler_time is not None:
                set_attr(SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
                         timing.scheduler_time)
            if timing.model_forward_time is not None:
                # NOTE(review): the /1000 looks like a ms -> s conversion;
                # confirm against where model_forward_time is recorded.
                set_attr(SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
                         timing.model_forward_time / 1000.0)
            if timing.model_execute_time is not None:
                set_attr(SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                         timing.model_execute_time)
2002

2003
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Check that the processed prompt is non-empty and fits the model.

        Args:
            inputs: Processed inputs; for encoder-decoder inputs either the
                ``"decoder"`` or the ``"encoder"`` entry is validated.
            lora_request: Accepted for the caller's convenience; not read
                by this method.

        Raises:
            ValueError: If the selected prompt has no token ids, or if the
                model is multimodal and the prompt is longer than
                ``model_config.max_model_len``.
        """
        if not is_encoder_decoder_inputs(inputs):
            prompt_inputs = inputs
        elif self.model_config.is_multimodal_model:
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            prompt_inputs = inputs["decoder"]
        else:
            prompt_inputs = inputs["encoder"]

        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids

        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        # The length limit below is only enforced for multimodal models.
        if not self.model_config.is_multimodal_model:
            return

        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            raise ValueError(
                f"The prompt (total length {len(prompt_ids)}) is too long "
                f"to fit into the model (context length {max_prompt_len}). "
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens plus multimodal tokens. For image "
                "inputs, the number of image tokens depends on the number "
                "of images, and possibly their aspect ratios as well.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the guided_decoding,
        logit_bias, allowed_token_ids and bad_words fields in
        sampling_params. Unsets guided_decoding, logit_bias and
        allowed_token_ids once consumed (bad_words is left in place) and
        adds the constructed logits processors to the logits_processors
        field. Returns the modified sampling params.

        Args:
            sampling_params: Params to derive the processors from. When
                guided decoding is requested a shallow copy is modified and
                returned; otherwise this object itself is modified.
            lora_request: Forwarded to ``get_tokenizer`` to select the
                tokenizer used to build the processors.

        Returns:
            The sampling params carrying the combined logits processors.
        """

        logits_processors = []

        if sampling_params.guided_decoding is not None:
            # Defensively copy sampling params since guided decoding logits
            # processors can have different state for each request
            sampling_params = copy.copy(sampling_params)
            guided_decoding = sampling_params.guided_decoding

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            # Fall back to the engine-wide backend when none was requested.
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config)
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
            tokenizer = self.get_tokenizer(lora_request=lora_request)

            processors = get_openai_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if len(sampling_params.bad_words) > 0:
            tokenizer = self.get_tokenizer(lora_request=lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

        if logits_processors:
            existing = sampling_params.logits_processors
            if existing is None:
                sampling_params.logits_processors = logits_processors
            else:
                # Rebind to a new list instead of extending in place:
                # copy.copy() above is shallow, so the copy still shares
                # this list object with the caller's params. An in-place
                # extend would leak this request's processors into it.
                sampling_params.logits_processors = [
                    *existing, *logits_processors
                ]

        return sampling_params