llm_engine.py 95.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
import copy
Antoni Baum's avatar
Antoni Baum committed
5
import time
6
7
import threading
import queue
8
from collections import Counter as collectionsCounter
9
from collections import deque
10
from contextlib import contextmanager
11
from dataclasses import dataclass
12
from functools import partial
13
14
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
                    List, Mapping, NamedTuple, Optional)
15
from typing import Sequence as GenericSequence
16
from typing import Set, Type, Union, cast, overload
17

18
import torch
19
from typing_extensions import TypeVar, deprecated
20

21
import vllm.envs as envs
22
23
24
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
25
26
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.engine.arg_utils import EngineArgs
28
from vllm.engine.metrics_types import StatLoggerBase, Stats
29
30
31
32
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
33
34
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
35
from vllm.executor.executor_base import ExecutorBase
36
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
37
                         PromptType, SingletonInputsAdapter)
38
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
39
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.logger import init_logger
41
from vllm.logits_process import get_bad_words_logits_processors
42
from vllm.lora.request import LoRARequest
43
44
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
45
from vllm.model_executor.layers.sampler import SamplerOutput
46
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
47
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
48
49
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
50
from vllm.prompt_adapter.request import PromptAdapterRequest
51
from vllm.sampling_params import RequestOutputKind, SamplingParams
52
53
54
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
zhuwenwen's avatar
zhuwenwen committed
55
                           SequenceGroupOutput, SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID)
56
57
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
58
from vllm.transformers_utils.detokenizer import Detokenizer
59
from vllm.transformers_utils.tokenizer import AnyTokenizer
60
from vllm.transformers_utils.tokenizer_group import (
61
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
62
63
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
64
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
65
from vllm.version import __version__ as VLLM_VERSION
lizhigong's avatar
lizhigong committed
66
from vllm.profiler.prof import profile
67
68

logger = init_logger(__name__)

# Passed as ``local_interval`` to both default stat loggers built in
# LLMEngine.__init__ (value in seconds, per the name).
_LOCAL_LOGGING_INTERVAL_SEC = 5

# Tokenizer-group type accepted/returned by LLMEngine.get_tokenizer_group.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
# Request-output types checked by LLMEngine.validate_output(s).
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
73
74


75
76
77
78
79
@dataclass
class SchedulerOutputState:
    """Per-virtual-engine cache of scheduler outputs, used by multi-step."""
    # Metadata of the sequence groups scheduled for the cached step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # The scheduler's result for the cached step.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Whether async output processing is allowed for the cached step.
    allow_async_output_proc: bool = False
    # Most recent sampler output kept alongside the cached schedule.
    last_output: Optional[SamplerOutput] = None
82
83


84
85
86
87
88
89
class OutputData(NamedTuple):
    """One queued batch of model outputs plus the scheduling state behind it."""
    # Sampler outputs; may hold outputs from several multi-step steps.
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # True when this output comes from the first step of a multi-step run;
    # always True when multi-step is disabled. The value is not meaningful
    # when ``outputs`` contains outputs from more than one step.
    is_first_step_output: Optional[bool]
    # SchedulerContext.append_output always starts this as an empty list.
    # NOTE(review): meaning of the ints is defined by consumers not visible
    # here (presumably indices to skip) — confirm.
    skip: List[int]


99
class SchedulerContext:
    """Mutable per-virtual-engine state used while processing model outputs.

    Holds a queue of pending :class:`OutputData` entries, the request outputs
    produced so far, and the latest scheduling results.
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        # Pending outputs; append_output() pushes new entries on the right.
        self.output_queue: Deque[OutputData] = deque()
        # Request outputs accumulated for the caller.
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []
        # Latest scheduling results; unset until a step is scheduled.
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None
        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Queue one batch of sampler outputs with its scheduling context.

        Every queued entry receives its own fresh, empty ``skip`` list.
        """
        entry = OutputData(outputs, seq_group_metadata_list,
                           scheduler_outputs, is_async, is_last_step,
                           is_first_step_output, [])
        self.output_queue.append(entry)
124
125


126
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
127
    """An LLM engine that receives requests and generates texts.
128

Woosuk Kwon's avatar
Woosuk Kwon committed
129
    This is the main class for the vLLM engine. It receives requests
130
131
132
133
134
135
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

136
137
    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
138

139
    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
140
    :ref:`engine-args`)
141
142
143
144
145
146
147

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
148
        device_config: The configuration related to the device.
149
150
151
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
152
153
        executor_class: The model executor class for managing distributed
            execution.
154
        prompt_adapter_config (Optional): The configuration related to serving
155
            prompt adapters.
156
        log_stats: Whether to log statistics.
157
        usage_context: Specified entry point, used for usage info collection.
158
    """
159

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Context manager that turns on request-output type validation.

        While active, ``validate_output`` / ``validate_outputs`` perform an
        ``isinstance`` check on every output. On exit the previous value of
        ``DO_VALIDATE_OUTPUT`` is restored. This happens even if the body
        raises, so a failing block (or a nested use) cannot leave validation
        switched on for the rest of the process.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` typed as ``output_type``.

        When output validation is enabled (``DO_VALIDATE_OUTPUT``), the value
        is checked with ``isinstance`` first.

        Raises:
            TypeError: if validation is enabled and ``output`` is not an
                instance of ``output_type``.
        """
        should_check = TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT
        if should_check and not isinstance(output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return cast(_O, output)
186
187
188

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return ``outputs`` typed as a list of ``output_type``.

        With validation disabled, the caller's sequence object is handed back
        unchanged (no copy is made). With validation enabled, every element
        is ``isinstance``-checked and a new list is returned.

        Raises:
            TypeError: if validation is enabled and any element is not an
                instance of ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            return cast(List[_O], outputs)

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)
        return checked

    tokenizer: Optional[BaseTokenizerGroup]

211
212
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the engine from a resolved :class:`VllmConfig`.

        Args:
            vllm_config: Engine configuration; its sub-configs are unpacked
                onto attributes (``model_config``, ``cache_config``, ...).
            executor_class: Instantiated as
                ``executor_class(vllm_config=vllm_config)`` to run the model.
            log_stats: Whether stat loggers are set up.
            usage_context: Entry point reported in usage stats, if enabled.
            stat_loggers: Pre-built stat loggers keyed by name. When None and
                ``log_stats`` is set, "logging" and "prometheus" loggers are
                created.
            input_registry: Registry used to create ``self.input_processor``.
            mm_registry: Multi-modal registry passed to the preprocessor.
            use_cached_outputs: Stored on the engine and logged at start-up.

        Environment:
            VLLM_TREE_DECODING: ``'1'`` sets ``self.tree_decoding``.
            VLLM_ZERO_OVERHEAD: ``'1'`` sets ``self.zero_overhead`` and starts
                the zero-overhead worker thread.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig()
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = (vllm_config.observability_config
                                     or ObservabilityConfig())

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(vllm_config=vllm_config)

        # Pooling models do not need KV-cache profiling/allocation.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cached-output slot and one scheduler context per virtual
        # engine (pipeline-parallel stage).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            # weak_bind avoids the callbacks keeping the engine alive.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
        if self.zero_overhead:
            self.async_d2h = None
            self.last_record = None
            self.async_event = torch.cuda.Event(enable_timing=False)
            # NOTE(review): the thread target is a bound method, so a running
            # thread holds a strong reference to the engine (unlike the
            # weak_bind / self-free closures above) and __del__ cannot fire
            # while it runs; it is also non-daemon. Confirm shutdown logic.
            self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
            self.q_recorder = queue.Queue()
            # A leading None is queued so the first step is ignored.
            self.q_recorder.put(None)
            self.thread_running = True
            # Main-thread -> scheduler-thread signal, initially unavailable.
            self.sem_m2s = threading.Semaphore(0)
            self.zero_thread.start()
        profile.StartTracer()
426

427
428
429
430
431
432
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The executor reports how many GPU and CPU (swap) blocks are available.
        ``cache_config.num_gpu_blocks_override``, when set, replaces the GPU
        count. The final counts are written back to ``cache_config`` before
        the executor allocates the cache. Total wall time is logged.
        """
        t0 = time.time()

        gpu_blocks, cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        override = self.cache_config.num_gpu_blocks_override
        if override is not None:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", gpu_blocks, override)
            gpu_blocks = override

        # Record the final counts on the shared config, then allocate.
        self.cache_config.num_gpu_blocks = gpu_blocks
        self.cache_config.num_cpu_blocks = cpu_blocks
        self.model_executor.initialize_cache(gpu_blocks, cpu_blocks)

        took = time.time() - t0
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), took)
452

453
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Select the executor class for ``engine_config``.

        ``parallel_config.distributed_executor_backend`` may be an
        ``ExecutorBase`` subclass, used as-is. Otherwise, for
        ``world_size > 1`` it must be one of ``"ray"``, ``"mp"``, ``"uni"``
        or ``"external_launcher"``. With ``world_size <= 1`` the single-process
        ``UniProcExecutor`` is used.

        Raises:
            TypeError: if a class is given that does not subclass
                ``ExecutorBase``.
            ValueError: if ``world_size > 1`` and the backend is not one of
                the recognized names. Previously this surfaced as an
                ``UnboundLocalError``.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif engine_config.parallel_config.world_size > 1:
            if distributed_executor_backend == "ray":
                from vllm.executor.ray_distributed_executor import (
                    RayDistributedExecutor)
                executor_class = RayDistributedExecutor
            elif distributed_executor_backend == "mp":
                from vllm.executor.mp_distributed_executor import (
                    MultiprocessingDistributedExecutor)
                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                    "multiprocessing distributed executor backend does not "
                    "support VLLM_USE_RAY_SPMD_WORKER=1")
                executor_class = MultiprocessingDistributedExecutor
            elif distributed_executor_backend == "uni":
                # JAX-style, single-process, multi-device executor.
                from vllm.executor.uniproc_executor import UniProcExecutor
                executor_class = UniProcExecutor
            elif distributed_executor_backend == "external_launcher":
                # executor with external launcher
                from vllm.executor.uniproc_executor import (  # noqa
                    ExecutorWithExternalLauncher)
                executor_class = ExecutorWithExternalLauncher
            else:
                # Without this branch executor_class would be unbound here.
                raise ValueError(
                    "Unrecognized distributed_executor_backend "
                    f"{distributed_executor_backend!r} for world_size > 1; "
                    "expected 'ray', 'mp', 'uni', 'external_launcher' or an "
                    "ExecutorBase subclass.")
        else:
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The arguments are turned into a :class:`VllmConfig`, an executor class
        is chosen for it, and statistics logging is enabled unless
        ``engine_args.disable_log_stats`` is set.
        """
        vllm_config = engine_args.create_engine_config(usage_context)
        return cls(
            vllm_config=vllm_config,
            executor_class=cls._get_executor_cls(vllm_config),
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
512

513
514
515
516
517
    def __reduce__(self):
        """Refuse pickling.

        Guarantees the engine is never captured in the closure used to
        initialize Ray worker actors.
        """
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

518
519
520
521
522
523
    def __del__(self):
        """Shut down the model executor when the engine is garbage collected.

        ``getattr`` guards against ``__init__`` having failed before
        ``model_executor`` was assigned.
        """
        executor = getattr(self, "model_executor", None)
        if executor:
            executor.shutdown()

524
    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the engine's tokenizer group, checked against ``group_type``.

        Raises:
            ValueError: if the tokenizer was not initialized
                (``skip_tokenizer_init`` is True).
            TypeError: if the tokenizer group is not a ``group_type``.
        """
        group = self.tokenizer
        if group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        if isinstance(group, group_type):
            return group
        raise TypeError("Invalid type of tokenizer group. "
                        f"Expected type: {group_type}, but "
                        f"found type: {type(group)}")
539

540
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` (base tokenizer if None)."""
        group = self.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
545

546
547
548
549
550
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Build the tokenizer group from the engine's configs."""
        config_kwargs = dict(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            lora_config=self.lora_config,
        )
        return init_tokenizer_from_configs(**config_kwargs)
552

553
554
    def _verify_args(self) -> None:
        """Cross-check the engine's configs against each other.

        The model and cache configs are checked against the parallel config.
        LoRA and prompt-adapter configs, when present, are checked against
        the model (and, for LoRA, scheduler) config.
        """
        parallel = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel)
        self.cache_config.verify_with_parallel_config(parallel)

        lora = self.lora_config
        if lora:
            lora.verify_with_model_config(self.model_config)
            lora.verify_with_scheduler_config(self.scheduler_config)

        adapter = self.prompt_adapter_config
        if adapter:
            adapter.verify_with_model_config(self.model_config)
563

564
565
566
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Sampling requests with ``params.n > 1`` are delegated to
        ``ParallelSampleSequenceGroup.add_request`` and ``None`` is returned.
        Otherwise a single :class:`SequenceGroup` is built (with an encoder
        sequence for encoder-decoder inputs). It is given to the scheduler
        with the fewest unfinished sequence groups and then returned.

        Raises:
            ValueError: if ``params`` is neither SamplingParams nor
                PoolingParams.
        """
        is_sampling = isinstance(params, SamplingParams)

        # Parallel sampling fans out into several sequence groups elsewhere.
        if is_sampling and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)

        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        decoder_inputs = processed_inputs
        encoder_inputs = None
        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = processed_inputs["decoder"]
            encoder_inputs = processed_inputs["encoder"]

        # The decoder and the optional encoder sequence share one seq_id.
        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)
        encoder_seq = None
        if encoder_inputs is not None:
            encoder_seq = Sequence(seq_id, encoder_inputs, block_size,
                                   eos_token_id, lora_request,
                                   prompt_adapter_request)

        if is_sampling:
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Balance load across virtual engines: the scheduler with the fewest
        # unfinished sequence groups wins (the first one on ties).
        target_scheduler = min(
            self.scheduler,
            key=lambda sched: sched.get_num_unfinished_seq_groups())
        target_scheduler.add_seq_group(seq_group)

        return seq_group

648
649
    def stop_remote_worker_execution_loop(self) -> None:
        """Forward the stop request to the model executor."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
650

651
    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only signature for the preferred ``prompt`` form.
        ...

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only signature for the deprecated keyword-only ``inputs``
        # form; the deprecation message now has balanced quotes.
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: The LoRA request to add.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter request to add.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            inputs: Deprecated alias for ``prompt``; when given it overrides
                ``prompt``.

        Raises:
            AssertionError: If no prompt (via ``prompt`` or ``inputs``) or no
                ``params`` were supplied.
            ValueError: If a LoRA request is given while LoRA is not enabled,
                if a non-zero priority is given without priority scheduling,
                or if guided decoding / logits processors are requested with
                multi-step scheduling (``num_scheduler_steps > 1``).

        Details:
            - Set arrival_time to the current time if it is None.
            - Validate token prompts against the tokenizer vocabulary.
            - Preprocess and process the prompt into model inputs.
            - Hand the processed request to ``_add_processed_request``, which
              builds the :class:`~vllm.SequenceGroup` and adds it to a
              scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    sampling_params)
            >>> # continue the request processing
            >>> ...
        """
        # Support the deprecated keyword by folding it into `prompt`.
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if isinstance(params, SamplingParams) \
            and (params.guided_decoding or params.logits_processors) \
            and self.scheduler_config.num_scheduler_steps > 1:
            raise ValueError(
                "Guided decoding and logits processors are not supported "
                "in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        # Reject out-of-vocab token ids before any multimodal preprocessing.
        if self.tokenizer is not None:
            self._validate_token_prompt(
                prompt,
                tokenizer=self.get_tokenizer(lora_request=lora_request))

        preprocessed_inputs = self.input_preprocessor.preprocess(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    def _validate_token_prompt(self, prompt: PromptType,
                               tokenizer: AnyTokenizer):
        """Guard against out-of-vocab token ids in a token prompt.

        For some tokenizers, ``tokenizer.decode`` happily returns empty text
        for token ids that are out of vocab, and nothing else detects ids
        greater than the max token id before the model runs. Such ids later
        crash a CUDA kernel with an index-out-of-bounds error, taking down the
        entire engine.

        This must happen before multimodal input pre-processing, which may add
        dummy ``<image>`` tokens that aren't part of the tokenizer's
        vocabulary.

        Args:
            prompt: The request prompt; only token prompts are checked.
            tokenizer: Tokenizer whose ``max_token_id`` bounds valid ids.

        Raises:
            ValueError: If the largest prompt token id exceeds
                ``tokenizer.max_token_id``.
        """
        if not is_token_prompt(prompt):
            return

        prompt_ids = prompt["prompt_token_ids"]
        if len(prompt_ids) == 0:
            # Empty prompt check is handled later
            return

        max_input_id = max(prompt_ids)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a generation request.

        Validates logprob limits, attaches logits processors, clones the
        sampling params (the clone does not deep-copy LogitsProcessor
        objects) and fills defaults from the generation config.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model's ``max_logprobs``.
        """
        limit = self.get_model_config().max_logprobs
        for requested in (sampling_params.logprobs,
                          sampling_params.prompt_logprobs):
            if requested and requested > limit:
                raise ValueError(f"Cannot request more than "
                                 f"{limit} logprobs.")

        # Defensive copy for the sampler, taken after logits processors are
        # attached.
        params = self._build_logits_processors(sampling_params,
                                               lora_request).clone()
        params.update_from_generation_config(self.generation_config_fields,
                                             seq.eos_token_id)

        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a pooling request.

        The pooling params handed to the group are a fresh clone, so the
        pooler never shares the caller's object.
        """
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params.clone(),
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one request, or several, by ID.

        The ID (or iterable of IDs) is forwarded unchanged to
        :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` on every
        scheduler the engine owns.

        Args:
            request_id: The ID(s) of the request to abort.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        for sched in self.scheduler:
            sched.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Return the engine's model configuration (``self.model_config``)."""
        model_config = self.model_config
        return model_config

    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's parallel configuration."""
        parallel_config = self.parallel_config
        return parallel_config

    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's decoding configuration."""
        decoding_config = self.decoding_config
        return decoding_config

    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's scheduler configuration."""
        scheduler_config = self.scheduler_config
        return scheduler_config

    def get_lora_config(self) -> Optional[LoRAConfig]:
        """Return the engine's LoRA configuration.

        The annotation is Optional because the engine may run without LoRA;
        ``add_request`` rejects LoRA requests when this value is falsy.
        """
        return self.lora_config

    def get_num_unfinished_requests(self) -> int:
        """Return the number of unfinished sequence groups, summed over every
        scheduler (one per virtual engine)."""
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total

    def has_unfinished_requests(self) -> bool:
        """Return True as soon as any scheduler reports unfinished sequences;
        False if none do (or there are no schedulers)."""
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Report whether the scheduler at index ``virtual_engine`` still has
        unfinished sequences."""
        engine_scheduler = self.scheduler[virtual_engine]
        return engine_scheduler.has_unfinished_seqs()

    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache of every scheduler (all virtual engines).

        Every scheduler is asked to reset, even if an earlier one reports
        failure. The previous `success and ...` chain short-circuited and
        silently skipped the remaining schedulers after the first failure.

        Returns:
            True only if every scheduler reported a successful reset
            (vacuously True when there are no schedulers).
        """
        results = [sched.reset_prefix_cache() for sched in self.scheduler]
        return all(results)

    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Record pooled data on a pooling request and finish its sequences.

        Only ``outputs[0].data`` is stored (as ``seq_group.pooled_data``).
        Every sequence in the group is then marked
        ``SequenceStatus.FINISHED_STOPPED``.
        """
        seq_group.pooled_data = outputs[0].data
        for finished_seq in seq_group.get_seqs():
            finished_seq.status = SequenceStatus.FINISHED_STOPPED

    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """Advance num_computed_tokens of a prompt sequence group under
        multi-step scheduling.

        Args:
            seq_group: SequenceGroup whose num_computed_tokens is updated.
            seq_group_meta: Metadata of the given SequenceGroup; its
                ``token_chunk_size`` is the amount added.
            is_first_step_output: When available, tells whether the appended
                output token is the output of the first step in multi-step.
                ``None`` means outputs from all steps of the multi-step run
                were submitted in a single burst.
        """
        assert self.scheduler_config.is_multi_step

        # Decode groups are updated after their tokens are appended instead.
        if not seq_group_meta.is_prompt:
            return

        if not self.scheduler_config.chunked_prefill_enabled:
            # Plain multi-step: prompt sequences are single-stepped, so the
            # update always applies.
            assert seq_group.state.num_steps == 1
            should_update = True
        else:
            # Multi-step + chunked prefill: scheduled prompts are fully
            # processed in the first step, so only that step (or a burst of
            # all steps) counts.
            should_update = (is_first_step_output is None
                             or is_first_step_output)

        if should_update:
            seq_group.update_num_computed_tokens(
                seq_group_meta.token_chunk_size)

    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups.

        Nothing is returned. RequestOutputs are appended to
        ``ctx.request_outputs``. When ``process_request_outputs_callback`` is
        set, they are handed to it and the list is cleared.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, then only this request is processed. The
                head of ``ctx.output_queue`` is then read without being
                popped, and the request's index is appended to that entry's
                ``skip`` list so a later full pass does not process it again.
        """
        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
            if self.scheduler_config.is_multi_step:
                outputs_by_sequence_group = create_output_by_sequence_group(
                    outputs, len(seq_group_metadata_list))
            elif self.speculative_config:
                # Decodes are multi-steps while prefills are not, outputting at
                # most 1 token. Separate them so that we can trigger chunk
                # processing without having to pad or copy over prompts K times
                # to match decodes structure (costly with prompt_logprobs).
                num_prefills = sum(sg.is_prompt
                                   for sg in seq_group_metadata_list)
                prefills, decodes = outputs[:num_prefills], outputs[
                    num_prefills:]
                outputs_by_sequence_group = create_output_by_sequence_group(
                    decodes,
                    num_seq_groups=len(seq_group_metadata_list) - num_prefills)
                outputs_by_sequence_group = [p.outputs for p in prefills
                                             ] + outputs_by_sequence_group
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
        elif len(outputs) == 1:
            outputs_by_sequence_group = outputs
        else:
            return None

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        finished_before: List[int] = []
        finished_now: List[int] = []
        # Groups whose output (tree decoding) contained no valid token.
        empty_seq_indices: List[int] = []
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            # Tree-style speculative decoding may produce no valid token for
            # a group in the first step. Such groups are skipped for this
            # pass: no token accounting, no output processing. Guard on
            # `output` (the list indexed below), not on `outputs`, which is
            # always non-empty at this point.
            if (self.tree_decoding and output
                    and isinstance(output[0], CompletionSequenceGroupOutput)):
                samples = [o.samples[0] for o in output]
                valid_samples = [
                    sample for sample in samples
                    if sample.output_token != VLLM_INVALID_TOKEN_ID
                ]
                if len(valid_samples) == 0:
                    empty_seq_indices.append(i)
                    continue

            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
                        seq_group_meta.token_chunk_size or 0)

            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            if self.model_config.runner_type == "pooling":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Create the outputs. Indices already handled (skipped earlier,
        # finished before/now, or empty tree-decoding output) are collected
        # into one set so each membership test is O(1) rather than scanning
        # four lists per index.
        excluded = set(skip)
        excluded.update(finished_before, finished_now, empty_seq_indices)
        for i in indices:
            if i in excluded:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        for seq_group in scheduler_outputs.ignored_seq_groups:
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _fix_last_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Overwrite the last appended token of each sampled sequence with the
        token actually sampled for it.

        Sampled token rows come from ``self.async_d2h`` (the host copy of the
        previous step's sampler tensor, started with ``non_blocking=True`` in
        ``thread_zero_overhead``). Sequence ids come from
        ``output[0].sampler_out_ids``. For every unfinished group that
        samples, the first row whose id equals the group's sequence id
        supplies ``row[0]``. That value is stored on the sample and passed to
        ``seq.fix_last_token_id``. Groups with no matching id are left
        untouched.

        NOTE(review): this presumably relies on the caller having synchronized
        on ``self.async_event`` before the host copy is read. TODO confirm.
        """
        sampled_rows = self.async_d2h.tolist()
        sampled_seq_ids = output[0].sampler_out_ids.tolist()

        # Index rows by seq_id once instead of rescanning the whole list for
        # every group (O(n) instead of O(n^2)). setdefault keeps the first
        # occurrence, matching the previous first-match-then-break lookup.
        row_by_seq_id = {}
        for token_row, seq_id in zip(sampled_rows, sampled_seq_ids):
            row_by_seq_id.setdefault(seq_id, token_row)

        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue
            if not seq_group_metadata.do_sample:
                continue

            sample = sequence_group_outputs.samples[0]

            assert len(seq_group.seqs) == 1
            seq = seq_group.seqs[0]

            token_row = row_by_seq_id.get(seq.seq_id)
            if token_row is not None:
                sample.output_token = token_row[0]
                seq.fix_last_token_id(sample.output_token)

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Append one step's sampled tokens to their sequences.

        The output processor normally does this. It is done here so that the
        worker can run the next forward pass asynchronously. Finished groups
        are skipped. Every sampling group must carry exactly one sample and
        exactly one sequence (``sampling_params.n == 1``).
        """
        groups = zip(seq_group_metadata_list, output, scheduled_seq_groups)
        for meta, group_output, scheduled in groups:
            seq_group = scheduled.seq_group
            if seq_group.is_finished():
                continue

            if self.scheduler_config.is_multi_step:
                # Only prefill groups are actually updated by this helper.
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, meta, seq_group.state.num_steps == 1)
            else:
                chunk = meta.token_chunk_size
                if chunk is None:
                    chunk = 0
                seq_group.update_num_computed_tokens(chunk)

            if not meta.do_sample:
                continue

            assert len(group_output.samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1)")
            sample = group_output.samples[0]

            assert len(seq_group.seqs) == 1
            seq = seq_group.seqs[0]

            if not self.scheduler_config.is_multi_step:
                seq.append_token_id(sample.output_token, sample.logprobs)
                continue

            # Evaluated before append_token_id, which may change the count.
            prefill_append = seq.data.get_num_uncomputed_tokens() == 0
            seq.append_token_id(sample.output_token, sample.logprobs)
            if not prefill_append:
                seq_group.update_num_computed_tokens(1)

    def trans_last_output_tensor(self, last_output) -> Optional[torch.Tensor]:
        """Placeholder hook for converting the previous step's output.

        Not implemented: ``last_output`` is ignored and ``None`` is always
        returned. The return annotation is Optional to reflect that.
        """
        return None

    def finish_thread(self):
        """Signal the zero-overhead worker thread to stop.

        This is a no-op unless ``self.zero_overhead`` is set. Otherwise it
        clears ``thread_running`` and releases ``sem_m2s`` once. A thread
        blocked in ``thread_zero_overhead`` wakes up, sees the cleared flag,
        and leaves its loop.
        """
        if not self.zero_overhead:
            return
        self.thread_running = False
        self.sem_m2s.release()

    def thread_zero_overhead(self):
        """Worker-thread loop for zero-overhead scheduling.

        Each iteration first waits on ``self.sem_m2s``. If
        ``self.thread_running`` has been cleared (``finish_thread`` does this
        before releasing the semaphore), the loop exits. Otherwise it:

        1. schedules virtual engine 0 and records the schedule on its context;
        2. if a previous step exists, starts a non-blocking device-to-host
           copy of that step's sampled tokens into ``self.async_d2h``, records
           ``self.async_event`` and puts the previous record on
           ``self.q_recorder``;
        3. executes the model, passing the previous step's sampled ids and
           tensor along in the ExecuteModelRequest;
        4. appends the new tokens via ``_advance_to_next_step`` and stores
           ``[outputs, seq_group_metadata_list, scheduler_outputs]`` in
           ``self.last_record`` for the next iteration.

        NOTE(review): an exception here ends the thread silently; whoever
        consumes ``q_recorder`` would presumably block. Confirm the handling
        in ``zero_overhead_step``.
        """
        while True:
            self.sem_m2s.acquire()
            if not self.thread_running:
                break

            virtual_engine = 0
            ctx = self.scheduler_contexts[virtual_engine]

            # Clear outputs for each new scheduler iteration
            ctx.request_outputs.clear()
            # Schedule iteration. The async-output-processing flag is unused
            # here: no async callback is attached to the request below.
            (seq_group_metadata_list, scheduler_outputs,
             _allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            assert seq_group_metadata_list is not None
            assert scheduler_outputs is not None

            last_outputs_ids = None
            last_outputs_tensor = None
            if self.last_record is not None:
                last_output = self.last_record[0][0]
                last_outputs_ids, last_outputs_tensor = (
                    last_output.sampler_out_ids,
                    last_output.sampler_out_tenosr)
                # Kick off the host copy without blocking; async_event is
                # recorded right after it (presumably so readers can sync).
                self.async_d2h = last_outputs_tensor.to('cpu',
                                                        non_blocking=True)
                self.async_event.record()
                self.q_recorder.put(self.last_record)
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids,
                last_outputs_ids=last_outputs_ids,
                last_outputs_sample=last_outputs_tensor)

            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

            self._advance_to_next_step(
                outputs[0], seq_group_metadata_list,
                scheduler_outputs.scheduled_seq_groups)
            self.last_record = [
                outputs, seq_group_metadata_list, scheduler_outputs
            ]

    def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Perform one engine step in zero-overhead mode.

        Model execution runs on a background thread (``thread_zero_overhead``),
        which is started lazily here when ``self.thread_running`` is false.
        Each call lets the worker run one pass (``sem_m2s.release()``) and
        then blocks on ``q_recorder`` for the record
        ``[outputs, seq_group_metadata_list, scheduler_outputs]`` of the batch
        the worker executed previously. The worker publishes that record
        before it executes the model on the newly scheduled batch, so the
        output processing done here can proceed while that model call runs.

        Returns:
            The request outputs produced for the previously executed batch,
            or an empty list when a ``None`` record is received (the first
            step, when no batch has been executed yet).
        """
        if not self.thread_running:
            self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
            self.thread_running = True
            self.zero_thread.start()

        # Let the worker run one pass, then wait for the previous record.
        self.sem_m2s.release()
        record = self.q_recorder.get()
        if record is None:  # None is for the first step
            # Return an empty list instead of None: step() forwards this
            # value and callers iterate over it.
            return []

        virtual_engine = 0
        ctx = self.scheduler_contexts[virtual_engine]
        outputs, seq_group_metadata_list, scheduler_outputs = record
        # The worker has already stored the *new* schedule on the context;
        # point it back at the batch being post-processed here.
        ctx.seq_group_metadata_list = seq_group_metadata_list
        ctx.scheduler_outputs = scheduler_outputs
        # Wait on the event the worker recorded right after issuing its
        # non-blocking device-to-host copy (presumably a CUDA event — confirm).
        self.async_event.synchronize()
        self._fix_last_step(
            outputs, seq_group_metadata_list,
            scheduler_outputs.scheduled_seq_groups)

        # is_first_step_output is True only when the num_steps of all
        # the sequences are 1. When the num_steps > 1,
        # multi_step_model_runner does the first-step output append.
        is_first_step_output: bool = False if not seq_group_metadata_list \
            else seq_group_metadata_list[0].state.num_steps == 1

        # Add results to the output_queue
        ctx.append_output(outputs=outputs,
                          seq_group_metadata_list=seq_group_metadata_list,
                          scheduler_outputs=scheduler_outputs,
                          is_async=True,
                          is_last_step=True,
                          is_first_step_output=is_first_step_output)

        # Outputs are always processed right here; the worker does not
        # attach an async callback to its requests.
        self._process_model_outputs(ctx=ctx)

        # Log stats.
        self.do_log_stats(scheduler_outputs, outputs)

        # Tracing
        self.do_tracing(scheduler_outputs)

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()
        return ctx.request_outputs
lizhigong's avatar
lizhigong committed
1443

1444
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Perform one decoding iteration and return newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedule the sequences to run in the next iteration
              and the token blocks to swap in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Call the distributed executor to execute the model.
            - Step 3: Process the model output. This mainly includes:

                - Decoding the relevant outputs.
                - Updating the scheduled sequence groups with model outputs
                  based on their `sampling parameters` (`use_beam_search` or
                  not).
                - Freeing the finished sequence groups.

            - Finally, create and return the newly generated results.

        When ``self.zero_overhead`` is set, the whole call is delegated to
        ``zero_overhead_step``.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        if self.zero_overhead:
            return self.zero_overhead_step()

        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # LLMEngine has no pipeline parallelism, so the only virtual engine
        # is index 0.
        virtual_engine = 0

        # State cached by a previous multi-step iteration (empty on the
        # first iteration).
        cached = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached.seq_group_metadata_list
        scheduler_outputs = cached.scheduler_outputs
        allow_async_output_proc = cached.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]
        # Every iteration starts with an empty set of request outputs.
        ctx.request_outputs.clear()

        if self._has_remaining_steps(seq_group_metadata_list):
            # The current multi-step batch is not done yet: reuse the cached
            # schedule and skip the scheduler.
            finished_requests_ids = []
        else:
            scheduler = self.scheduler[virtual_engine]
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc) = scheduler.schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = (
                scheduler.get_and_reset_finished_requests_ids())

            # Switching from async to sync output processing: flush what the
            # async path left in the queue first.
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            # With lookahead slots, remember this schedule for the
            # following iterations of the multi-step batch.
            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if scheduler_outputs.is_empty():
            # Nothing scheduled: finish any pending async post-processing
            # here; there are no model outputs this iteration.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []
        else:
            # The previous step's sampled token ids travel inside
            # ExecuteModelRequest so non-last PP stages can prepare their
            # inputs in place without a separate broadcast.
            prev_token_ids = self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                last_sampled_token_ids=prev_token_ids)
            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

            # Done here so this step's sampled token ids can be passed to
            # the next iteration.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)

        # Every scheduled sequence group has finished one more step.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if self._has_remaining_steps(seq_group_metadata_list):
            # Multi-step batch still in progress: no output processing yet.
            return ctx.request_outputs

        # All steps of the batch are done: drop the multi-step cache.
        if self.scheduler_config.is_multi_step:
            self.cached_scheduler_outputs[virtual_engine] = (
                SchedulerOutputState())

        # True only when every sequence runs a single step; with more steps
        # the multi-step model runner appends the first-step output itself.
        is_first_step_output: bool = (
            bool(seq_group_metadata_list)
            and seq_group_metadata_list[0].state.num_steps == 1)

        ctx.append_output(outputs=outputs,
                          seq_group_metadata_list=seq_group_metadata_list,
                          scheduler_outputs=scheduler_outputs,
                          is_async=allow_async_output_proc,
                          is_last_step=True,
                          is_first_step_output=is_first_step_output)

        if allow_async_output_proc:
            if outputs:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)
        else:
            # Synchronous path: process, log stats and trace right away.
            self._process_model_outputs(ctx=ctx)
            self.do_log_stats(scheduler_outputs, outputs)
            self.do_tracing(scheduler_outputs)

        if not self.has_unfinished_requests():
            # Drain the async postprocessor, if anything is pending.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Pause the execute-model loop in parallel workers until more
            # requests arrive, so they neither wait indefinitely in
            # torch.distributed ops (which may time out) nor block their RPC
            # thread from handling control-plane messages such as
            # add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()
        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1650

1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return whether the current multi-step batch still has steps left.

        Always False when multi-step scheduling is disabled or the list is
        empty/None. Otherwise every sequence group must report the same
        ``state.remaining_steps``, and the result is whether that count is
        greater than zero.

        Raises:
            AssertionError: if the sequence groups disagree on their
                remaining steps.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        # Generator (not a list) so the check stops at the first mismatch.
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Store a schedule so later multi-step iterations can reuse it.

        Writes the given scheduler results into
        ``self.cached_scheduler_outputs[virtual_engine]`` and resets that
        entry's ``last_output`` to None.
        """
        cached_state = self.cached_scheduler_outputs[virtual_engine]

        cached_state.seq_group_metadata_list = seq_group_metadata_list
        cached_state.scheduler_outputs = scheduler_outputs
        cached_state.allow_async_output_proc = allow_async_output_proc
        # Forget any output cached for the previous schedule.
        cached_state.last_output = None
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Cache the newest sampler output of this step for pipeline parallel.

        Acts only when pipeline parallelism is enabled (size > 1), ``output``
        is non-empty and its first element is not None. In that case the
        last element is stored as ``last_output`` of
        ``self.cached_scheduler_outputs[virtual_engine]``. The asserts require
        it to carry CPU sampled token ids and no device-side ids or probs.
        """
        if self.parallel_config.pipeline_parallel_size <= 1:
            return
        if len(output) == 0 or output[0] is None:
            return

        newest = output[-1]
        assert newest is not None
        assert newest.sampled_token_ids_cpu is not None
        assert newest.sampled_token_ids is None
        assert newest.sampled_token_probs is None
        self.cached_scheduler_outputs[virtual_engine].last_output = newest

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None
Antoni Baum's avatar
Antoni Baum committed
1706

1707
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` in ``self.stat_loggers`` under ``logger_name``.

        Raises:
            RuntimeError: if stat logging is disabled (``self.log_stats`` is
                false).
            KeyError: if a logger with that name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            RuntimeError: if stat logging is disabled (``self.log_stats`` is
                false).
            KeyError: if no logger with that name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        registry = self.stat_loggers
        if logger_name not in registry:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del registry[logger_name]

1725
1726
1727
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Compute the current stats and hand them to every stat logger.

        Originally described as a forced log used when no requests are
        active; it is also called with the scheduler/model outputs of a step.
        All arguments are forwarded to ``self._get_stats``. Does nothing when
        ``self.log_stats`` is false.
        """
        if not self.log_stats:
            return
        stats = self._get_stats(scheduler_outputs, model_output,
                                finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(stats)
1736

1737
1738
1739
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
1740
1741
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
1742
1743
1744
1745
1746
1747
1748
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
1749
1750
1751
1752
            finished_before: Optional, indices of sequences that were finished
                before. These sequences will be ignored.
            skip: Optional, indices of sequences that were preempted. These
                sequences will be ignored.
1753
        """
1754
        now = time.time()
Woosuk Kwon's avatar
Woosuk Kwon committed
1755

1756
1757
        # System State
        #   Scheduler State
1758
1759
1760
1761
1762
1763
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)
1764
1765

        # KV Cache Usage in %
1766
        num_total_gpu = self.cache_config.num_gpu_blocks
1767
        gpu_cache_usage_sys = 0.
1768
        if num_total_gpu:  # Guard against both None and 0
1769
1770
1771
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
1772
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
Woosuk Kwon's avatar
Woosuk Kwon committed
1773

1774
        num_total_cpu = self.cache_config.num_cpu_blocks
1775
        cpu_cache_usage_sys = 0.
1776
        if num_total_cpu:  # Guard against both None and 0
1777
1778
1779
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
1780
1781
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

1782
1783
1784
1785
1786
1787
1788
        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

1789
1790
1791
        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
harrywu's avatar
harrywu committed
1792
        num_tokens_iter = 0
1793
1794
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
1795
1796
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)
1797
1798
1799
1800

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
harrywu's avatar
harrywu committed
1801
1802
1803
1804
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
1805
1806
1807
        time_in_queue_requests: List[float] = []
        model_forward_time_requests: List[float] = []
        model_execute_time_requests: List[float] = []
1808
1809
1810
1811
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
harrywu's avatar
harrywu committed
1812
        max_num_generation_tokens_requests: List[int] = []
1813
        max_tokens_requests: List[int] = []
1814
1815
        finished_reason_requests: List[str] = []

1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
        # Lora requests
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

1835
1836
        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
1837
        if scheduler_outputs is not None:
1838
1839
1840
1841
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

1842
            num_generation_tokens_from_prefill_groups = 0
1843
1844
1845
1846
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.
1847
1848
1849

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
1850
1851
1852
1853
1854
                # Skip double logging when using async output proc
                if finished_before and idx in finished_before:
                    actual_num_batched_tokens -= 1
                    continue

1855
1856
1857
1858
                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if skip and idx in skip:
                    continue
1859

1860
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
1861
                seq_group = scheduled_seq_group.seq_group
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
1874
                        latency = seq_group.get_last_token_latency()
1875
1876
1877
1878
1879
1880
1881
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
1882
                    latency = seq_group.get_last_token_latency()
1883
                    time_per_output_tokens_iter.append(latency)
1884
1885
1886
1887
1888
1889
1890
1891
1892
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1
1893
1894
1895
1896
1897
1898

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
1899
                if seq_group.is_finished():
1900
                    # Latency timings
1901
1902
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
harrywu's avatar
harrywu committed
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
1915
1916
1917
1918
1919
1920
1921
1922
1923
                    if seq_group.metrics.time_in_queue is not None:
                        time_in_queue_requests.append(
                            seq_group.metrics.time_in_queue)
                    if seq_group.metrics.model_forward_time is not None:
                        model_forward_time_requests.append(
                            seq_group.metrics.model_forward_time)
                    if seq_group.metrics.model_execute_time is not None:
                        model_execute_time_requests.append(
                            seq_group.metrics.model_execute_time * 1000)
1924
1925
1926
1927
1928
1929
1930
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
harrywu's avatar
harrywu committed
1931
1932
1933
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
1934
1935
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
1936
1937
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
1950
                actual_num_batched_tokens - num_prompt_tokens_iter +
1951
                num_generation_tokens_from_prefill_groups)
harrywu's avatar
harrywu committed
1952
1953
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)
1954
1955
        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
1956
1957
        if model_output and isinstance(model_output[0], SamplerOutput) and (
                model_output[0].spec_decode_worker_metrics is not None):
1958
1959
1960
1961
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

1962
1963
        return Stats(
            now=now,
1964
1965
1966
1967
1968
1969
1970
1971
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
1972
1973
1974
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
1975
1976
1977
1978

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
harrywu's avatar
harrywu committed
1979
            num_tokens_iter=num_tokens_iter,
1980
1981
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
1982
            spec_decode_metrics=spec_decode_metrics,
1983
            num_preemption_iter=num_preemption_iter,
1984
1985
1986
1987

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
harrywu's avatar
harrywu committed
1988
1989
1990
1991
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
1992
1993
1994
            time_in_queue_requests=time_in_queue_requests,
            model_forward_time_requests=model_forward_time_requests,
            model_execute_time_requests=model_execute_time_requests,
1995
1996
1997
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
harrywu's avatar
harrywu committed
1998
1999
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
2000
            n_requests=n_requests,
2001
            max_tokens_requests=max_tokens_requests,
2002
            finished_reason_requests=finished_reason_requests,
2003
2004
2005
            max_lora=str(max_lora_stat),
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
2006

2007
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model executor.

        Returns the boolean the executor reports for the registration.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
2009
2010

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model executor to drop the LoRA adapter ``lora_id``.

        Returns the boolean the executor reports for the removal.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
2012

2013
    def list_loras(self) -> Set[int]:
        """Return the LoRA adapter IDs reported by the model executor."""
        executor = self.model_executor
        return executor.list_loras()
2015

2016
2017
2018
    def pin_lora(self, lora_id: int) -> bool:
        """Forward a pin request for LoRA adapter ``lora_id`` to the executor.

        Returns the boolean the executor reports.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Register a prompt adapter with the model executor.

        Returns the boolean the executor reports for the registration.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Ask the model executor to drop prompt adapter ``prompt_adapter_id``.

        Returns the boolean the executor reports for the removal.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the prompt adapter IDs reported by the model executor."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

2029
2030
2031
2032
2033
2034
    def start_profile(self) -> None:
        """Tell the model executor to begin profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Tell the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at the given ``level``.

        ``enable_sleep_mode`` must be set in the model config; otherwise the
        assertion fails before the executor is called.
        """
        sleep_mode_on = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_mode_on, ("Sleep mode is not enabled in the model config")
        self.model_executor.sleep(level=level)

    def wake_up(self) -> None:
        """Wake the model executor up from sleep mode.

        ``enable_sleep_mode`` must be set in the model config; otherwise the
        assertion fails before the executor is called.
        """
        sleep_mode_on = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_mode_on, ("Sleep mode is not enabled in the model config")
        self.model_executor.wake_up()

2045
    def check_health(self) -> None:
        """Run the tokenizer's health check (if a tokenizer is set), then the
        model executor's.

        Any exception raised by either check propagates to the caller.
        """
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        executor = self.model_executor
        executor.check_health()
2049
2050
2051
2052

    def is_tracing_enabled(self) -> bool:
        """Report whether a tracer has been configured on this engine."""
        tracer = self.tracer
        return tracer is not None

2053
2054
2055
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for every finished scheduled sequence group.

        Does nothing when no tracer is configured.

        Args:
            scheduler_outputs: Scheduler outputs whose
                ``scheduled_seq_groups`` are inspected in order.
            finished_before: Indices into ``scheduled_seq_groups`` to skip;
                per the inline comment below, this avoids tracing a group
                twice when async output processing is used.
        """
        if self.tracer is None:
            return

        # Convert once so the per-group membership test is O(1) instead of a
        # linear scan of the list for every scheduled group.
        skip_indices = set(finished_before) if finished_before else set()

        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if idx in skip_indices:
                continue

            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Record an ``llm_request`` server span for one sequence group.

        The span's start time is the group's arrival time (as integer
        nanoseconds). Its parent context is extracted from
        ``seq_group.trace_headers``. It carries the request's sampling
        parameters, token usage and latency metrics as attributes. Nothing is
        recorded when no tracer is configured or the group has no sampling
        params.
        """
        if self.tracer is None:
            return
        params = seq_group.sampling_params
        if params is None:
            return

        start_time_ns = int(seq_group.metrics.arrival_time * 1e9)
        parent_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=parent_context,
                start_time=start_time_ns) as span:
            metrics = seq_group.metrics
            arrival = metrics.arrival_time
            ttft = metrics.first_token_time - arrival
            e2e_time = metrics.finished_time - arrival
            completion_tokens = sum(
                seq.get_output_len() for seq in seq_group.get_finished_seqs())

            # Collected in the exact order they are written to the span.
            attributes = [
                (SpanAttributes.GEN_AI_RESPONSE_MODEL,
                 self.model_config.model),
                (SpanAttributes.GEN_AI_REQUEST_ID, seq_group.request_id),
                (SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                 params.temperature),
                (SpanAttributes.GEN_AI_REQUEST_TOP_P, params.top_p),
                (SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, params.max_tokens),
                (SpanAttributes.GEN_AI_REQUEST_N, params.n),
                (SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                 seq_group.num_seqs()),
                (SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                 len(seq_group.prompt_token_ids)),
                (SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                 completion_tokens),
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                 metrics.time_in_queue),
                (SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft),
                (SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time),
            ]

            # Optional latencies are only recorded when they were measured.
            if metrics.scheduler_time is not None:
                attributes.append(
                    (SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                     metrics.scheduler_time))
            if metrics.model_forward_time is not None:
                # NOTE(review): divided by 1000.0 unlike model_execute_time;
                # presumably forward time is stored in milliseconds — confirm
                # against RequestMetrics.
                attributes.append(
                    (SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                     metrics.model_forward_time / 1000.0))
            if metrics.model_execute_time is not None:
                attributes.append(
                    (SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                     metrics.model_execute_time))

            for attr_key, attr_value in attributes:
                span.set_attribute(attr_key, attr_value)
2123

2124
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Reject prompts that are empty or, for multimodal models, too long.

        Args:
            inputs: Processed inputs of one request. For encoder-decoder
                inputs, the decoder prompt is checked when the model is
                multimodal and the encoder prompt otherwise.
            lora_request: Not used by this check.

        Raises:
            ValueError: if the selected prompt has no token ids, or if a
                multimodal model's prompt is longer than ``max_model_len``.
        """
        is_multimodal = self.model_config.is_multimodal_model

        if not is_encoder_decoder_inputs(inputs):
            prompt_inputs = inputs
        else:
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            selected_part = "decoder" if is_multimodal else "encoder"
            prompt_inputs = inputs[selected_part]

        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        if not is_multimodal:
            return

        max_prompt_len = self.model_config.max_model_len
        num_prompt_tokens = len(prompt_ids)
        if num_prompt_tokens > max_prompt_len:
            raise ValueError(
                f"The prompt (total length {num_prompt_tokens}) is too long "
                f"to fit into the model (context length {max_prompt_len}). "
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens plus multimodal tokens. For image "
                "inputs, the number of image tokens depends on the number "
                "of images, and possibly their aspect ratios as well.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors from the guided_decoding, logit_bias,
        allowed_token_ids and bad_words fields in sampling_params.

        Clears guided_decoding, logit_bias and allowed_token_ids, appends the
        constructed processors to the logits_processors field, and returns
        the resulting sampling params.

        The caller's ``sampling_params`` is never mutated. When there is
        anything to build, the work is done on a shallow copy that receives
        a brand-new ``logits_processors`` list, and ``guided_decoding`` is
        copied before its backend is filled in. A single SamplingParams object
        may be shared by several requests. Mutating it in place, or extending
        a ``logits_processors`` list that is still shared with it after a
        shallow copy, would clear the caller's fields. It would also carry
        processors built for one request (including stateful
        guided-decoding ones) into later requests, with bad-words processors
        accumulating once per request.

        Args:
            sampling_params: The request's sampling parameters.
            lora_request: Optional LoRA request, used to select the tokenizer
                the processors are built with.

        Returns:
            ``sampling_params`` itself when none of the above fields are set,
            otherwise a shallow copy carrying the constructed processors.
        """
        has_guided_decoding = sampling_params.guided_decoding is not None
        has_openai_params = bool(sampling_params.logit_bias
                                 or sampling_params.allowed_token_ids)
        has_bad_words = len(sampling_params.bad_words) > 0
        if not (has_guided_decoding or has_openai_params or has_bad_words):
            return sampling_params

        # Defensively copy sampling params: the caller's object may be shared
        # across requests, and guided decoding logits processors can have
        # different state for each request.
        sampling_params = copy.copy(sampling_params)
        logits_processors = []

        if has_guided_decoding:
            # Copy as well: filling in the default backend below must not
            # write through to the caller's guided decoding params.
            guided_decoding = copy.copy(sampling_params.guided_decoding)

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config)
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if has_openai_params:
            tokenizer = self.get_tokenizer(lora_request=lora_request)
            processors = get_openai_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if has_bad_words:
            tokenizer = self.get_tokenizer(lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

        if logits_processors:
            # Build a fresh list rather than extending in place: after the
            # shallow copy, the existing list is still shared with the
            # caller's SamplingParams.
            existing = sampling_params.logits_processors or []
            sampling_params.logits_processors = [
                *existing, *logits_processors
            ]

        return sampling_params