llm_engine.py 94.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
import copy
Antoni Baum's avatar
Antoni Baum committed
5
import time
6
7
import threading
import queue
8
from collections import Counter as collectionsCounter
9
from collections import deque
10
from contextlib import contextmanager
11
from dataclasses import dataclass
12
from functools import partial
13
14
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
                    List, Mapping, NamedTuple, Optional)
15
from typing import Sequence as GenericSequence
16
from typing import Set, Type, Union, cast, overload
17

18
import torch
19
from typing_extensions import TypeVar, deprecated
20

21
import vllm.envs as envs
22
23
24
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
25
26
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.engine.arg_utils import EngineArgs
28
from vllm.engine.metrics_types import StatLoggerBase, Stats
29
30
31
32
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
33
34
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
35
from vllm.executor.executor_base import ExecutorBase
36
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
37
                         PromptType, SingletonInputsAdapter)
38
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
39
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.logger import init_logger
41
from vllm.logits_process import get_bad_words_logits_processors
42
from vllm.lora.request import LoRARequest
43
44
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
45
from vllm.model_executor.layers.sampler import SamplerOutput
46
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
47
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
48
49
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
50
from vllm.prompt_adapter.request import PromptAdapterRequest
51
from vllm.sampling_params import RequestOutputKind, SamplingParams
52
53
54
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
zhuwenwen's avatar
zhuwenwen committed
55
                           SequenceGroupOutput, SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID)
56
57
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
58
from vllm.transformers_utils.detokenizer import Detokenizer
59
from vllm.transformers_utils.tokenizer import AnyTokenizer
60
from vllm.transformers_utils.tokenizer_group import (
61
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
62
63
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
64
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
65
from vllm.version import __version__ as VLLM_VERSION
lizhigong's avatar
lizhigong committed
66
from vllm.profiler.prof import profile
67
68

logger = init_logger(__name__)

# Interval in seconds, passed as ``local_interval`` to the default stat
# loggers created by ``LLMEngine``.
_LOCAL_LOGGING_INTERVAL_SEC = 5

# Tokenizer-group type accepted and returned by
# ``LLMEngine.get_tokenizer_group``.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
# Request-output types that ``LLMEngine.validate_output(s)`` can check for.
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
73
74


75
76
77
78
79
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for one virtual engine.

    Used for multi-step scheduling. ``LLMEngine`` keeps one instance per
    pipeline-parallel rank in ``cached_scheduler_outputs``.
    """
    # Cached result of a scheduling pass (None until populated).
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None
    allow_async_output_proc: bool = False
    last_output: Optional[SamplerOutput] = None
82
83


84
85
86
87
88
89
class OutputData(NamedTuple):
    """One batch of model outputs queued on a ``SchedulerContext``.

    Instances are created by ``SchedulerContext.append_output`` and appended
    to ``SchedulerContext.output_queue``.
    """
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # Indicates if this output is from the first step of the
    # multi-step. When multi-step is disabled, this is always
    # set to True.
    # is_first_step_output is invalid when `outputs` has
    # outputs from multiple steps.
    is_first_step_output: Optional[bool]
    # Initialised to [] by SchedulerContext.append_output.
    # NOTE(review): presumably indices of entries to skip during output
    # processing - confirm against _process_model_outputs.
    skip: List[int]


99
class SchedulerContext:
    """Per-virtual-engine scheduling/output-processing state.

    ``LLMEngine`` keeps one instance per pipeline-parallel rank and passes it
    as ``ctx`` to its ``_process_model_outputs`` callbacks.
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        # Pending batches of model outputs, oldest first.
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Wrap the arguments in an ``OutputData`` and enqueue it.

        Each entry gets its own fresh, empty ``skip`` list.
        """
        self.output_queue.append(
            OutputData(outputs=outputs,
                       seq_group_metadata_list=seq_group_metadata_list,
                       scheduler_outputs=scheduler_outputs,
                       is_async=is_async,
                       is_last_step=is_last_step,
                       is_first_step_output=is_first_step_output,
                       skip=[]))
124
125


126
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
127
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine-args`)

    Args:
        vllm_config: The aggregate engine configuration. The engine reads the
            following sub-configurations from it:

            - model_config: The configuration related to the LLM model.
            - cache_config: The configuration related to the KV cache memory
              management.
            - parallel_config: The configuration related to distributed
              execution.
            - scheduler_config: The configuration related to the request
              scheduler.
            - device_config: The configuration related to the device.
            - lora_config (Optional): The configuration related to serving
              multi-LoRA.
            - speculative_config (Optional): The configuration related to
              speculative decoding.
            - prompt_adapter_config (Optional): The configuration related to
              serving prompt adapters.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
    """
159

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    # Class-level switch read by validate_output / validate_outputs and
    # switched on temporarily by enable_output_validation().
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Enable runtime type-checking of request outputs within a block.

        While active, :meth:`validate_output` and :meth:`validate_outputs`
        raise ``TypeError`` on mismatched outputs.

        On exit the previous value of ``DO_VALIDATE_OUTPUT`` is restored. This
        also happens when the body raises, so a failing body cannot leave
        validation switched on process-wide. Nested uses do not turn off an
        enclosing one.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return *output* typed as *output_type*, optionally checking it.

        The ``isinstance`` check runs only when ``DO_VALIDATE_OUTPUT`` is set
        (or under static type checking). Otherwise this is a plain cast.

        Raises:
            TypeError: if validation is enabled and *output* is not an
                instance of *output_type*.
        """
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

        return cast(_O, output)
186
187
188

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Return *outputs* typed as a list of *output_type*.

        When ``DO_VALIDATE_OUTPUT`` is set, every element is checked and a new
        list is returned. Otherwise *outputs* itself is returned unchanged,
        with no copy, so callers must not rely on receiving a fresh list.

        Raises:
            TypeError: if validation is enabled and any element is not an
                instance of *output_type*.
        """
        do_validate = cls.DO_VALIDATE_OUTPUT

        outputs_: List[_O]
        if TYPE_CHECKING or do_validate:
            outputs_ = []
            for output in outputs:
                if not isinstance(output, output_type):
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(output)}")

                outputs_.append(output)
        else:
            # Fast path: no per-element check and no copy. Static type
            # checkers treat TYPE_CHECKING as true and so presumably never
            # analyse this Sequence -> List assignment.
            outputs_ = outputs

        return outputs_

    # Tokenizer group created in __init__, or None when
    # model_config.skip_tokenizer_init is set (see get_tokenizer_group).
    tokenizer: Optional[BaseTokenizerGroup]

211
212
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Construct the engine and all of its runtime components.

        Initialization order matters. The KV cache is profiled and sized
        before the schedulers are created, because the block counts are
        written into ``cache_config``, which the schedulers receive.

        Args:
            vllm_config: Aggregate configuration. Each sub-config is also
                exposed as an attribute (``self.model_config`` etc.).
            executor_class: Executor type, instantiated with
                ``vllm_config=vllm_config``.
            log_stats: Whether to set up ``self.stat_loggers``.
            usage_context: Entry point reported with usage statistics.
            stat_loggers: Loggers to use instead of the default
                logging/Prometheus pair. Only consulted when ``log_stats``
                is true.
            input_registry: Registry that creates ``self.input_processor``.
            mm_registry: Multi-modal registry passed to the
                ``InputPreprocessor``.
            use_cached_outputs: Stored as ``self.use_cached_outputs``.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig()
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = (vllm_config.observability_config
                                     or ObservabilityConfig())

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(vllm_config=vllm_config)

        # The KV cache is only profiled/allocated for non-pooling runners.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            self._report_usage_stats(usage_context)

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One scheduler / context / cached-output slot per virtual engine.
        num_virtual_engines = self.parallel_config.pipeline_parallel_size

        self.cached_scheduler_outputs = [
            SchedulerOutputState() for _ in range(num_virtual_engines)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(num_virtual_engines)
        ]

        if self.model_config.use_async_output_proc:
            # NOTE: weak_bind is presumably used so that these callbacks do
            # not keep the engine alive (cf. get_tokenizer_for_seq above).
            process_model_outputs = weak_bind(self._process_model_outputs)
            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(num_virtual_engines)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                num_virtual_engines,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(num_virtual_engines)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                self._init_default_stat_loggers()

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        # Opt-in features, enabled by setting the variable to '1'.
        self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
        if self.zero_overhead:
            self._init_zero_overhead_state()

        profile.StartTracer()

    def _report_usage_stats(self, usage_context: UsageContext) -> None:
        """Report the engine's configuration summary via ``usage_message``.

        Called from ``__init__`` only when usage stats are enabled.
        """
        from vllm.model_executor.model_loader import (
            get_architecture_class_name)
        usage_message.report_usage(
            get_architecture_class_name(self.model_config),
            usage_context,
            extra_kvs={
                # Common configuration
                "dtype":
                str(self.model_config.dtype),
                "tensor_parallel_size":
                self.parallel_config.tensor_parallel_size,
                "block_size":
                self.cache_config.block_size,
                "gpu_memory_utilization":
                self.cache_config.gpu_memory_utilization,

                # Quantization
                "quantization":
                self.model_config.quantization,
                "kv_cache_dtype":
                str(self.cache_config.cache_dtype),

                # Feature flags
                "enable_lora":
                bool(self.lora_config),
                "enable_prompt_adapter":
                bool(self.prompt_adapter_config),
                "enable_prefix_caching":
                self.cache_config.enable_prefix_caching,
                "enforce_eager":
                self.model_config.enforce_eager,
                "disable_custom_all_reduce":
                self.parallel_config.disable_custom_all_reduce,
            })

    def _init_default_stat_loggers(self) -> None:
        """Set ``self.stat_loggers`` to the default logging + Prometheus pair.

        Also publishes ``cache_config`` through the Prometheus logger.
        """
        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from vllm.engine.metrics import (LoggingStatLogger,
                                         PrometheusStatLogger)

        self.stat_loggers = {
            "logging":
            LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                              vllm_config=self.vllm_config),
            "prometheus":
            PrometheusStatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels=dict(model_name=self.model_config.served_model_name),
                vllm_config=self.vllm_config),
        }
        self.stat_loggers["prometheus"].info("cache_config",
                                             self.cache_config)

    def _init_zero_overhead_state(self) -> None:
        """Create the zero-overhead bookkeeping and start its worker thread.

        Only called from ``__init__`` when ``VLLM_ZERO_OVERHEAD=1``.
        """
        self.async_d2h = None
        self.last_record = None
        self.async_event = torch.cuda.Event(enable_timing=False)
        self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
        self.q_recorder = queue.Queue()
        # A leading None is used so that the first step is ignored.
        self.q_recorder.put(None)
        # Signalled by the main thread to wake the scheduler thread.
        self.sem_m2s = threading.Semaphore(0)
        # NOTE(review): this is a non-daemon thread whose target is a bound
        # method, so it holds a strong reference to the engine. Confirm that
        # thread_zero_overhead terminates; otherwise __del__ and interpreter
        # exit may be blocked.
        self.zero_thread.start()
427

428
429
430
431
432
433
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache. When ``cache_config.num_gpu_blocks_override``
        is set, it replaces the profiled GPU block count. The final counts are
        written to ``cache_config`` before the executor allocates the cache,
        and the elapsed wall-clock time is logged.
        """
        start = time.time()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        # Record the final counts on the shared cache_config. The schedulers
        # created later in __init__ receive this same config object.
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.time() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
453

454
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Pick the executor class for ``engine_config``.

        ``parallel_config.distributed_executor_backend`` may be an
        ``ExecutorBase`` subclass, which is used as-is. When
        ``world_size > 1`` it may instead be one of ``"ray"``, ``"mp"``,
        ``"uni"`` or ``"external_launcher"``. Otherwise ``UniProcExecutor``
        is used.

        Raises:
            TypeError: if a class is given that is not an ``ExecutorBase``
                subclass.
            ValueError: if ``world_size > 1`` and the backend is not one of
                the supported names.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif engine_config.parallel_config.world_size > 1:
            if distributed_executor_backend == "ray":
                from vllm.executor.ray_distributed_executor import (
                    RayDistributedExecutor)
                executor_class = RayDistributedExecutor
            elif distributed_executor_backend == "mp":
                from vllm.executor.mp_distributed_executor import (
                    MultiprocessingDistributedExecutor)
                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                    "multiprocessing distributed executor backend does not "
                    "support VLLM_USE_RAY_SPMD_WORKER=1")
                executor_class = MultiprocessingDistributedExecutor
            elif distributed_executor_backend == "uni":
                # JAX-style, single-process, multi-device executor.
                from vllm.executor.uniproc_executor import UniProcExecutor
                executor_class = UniProcExecutor
            elif distributed_executor_backend == "external_launcher":
                # executor with external launcher
                from vllm.executor.uniproc_executor import (  # noqa
                    ExecutorWithExternalLauncher)
                executor_class = ExecutorWithExternalLauncher
            else:
                # Without this branch an unknown name leaves executor_class
                # unbound and the return below fails with UnboundLocalError.
                raise ValueError(
                    "Unsupported distributed_executor_backend "
                    f"{distributed_executor_backend!r} for world_size="
                    f"{engine_config.parallel_config.world_size}. Expected "
                    "'ray', 'mp', 'uni', 'external_launcher' or an "
                    "ExecutorBase subclass.")
        else:
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Args:
            engine_args: Arguments used to build the engine's ``VllmConfig``.
            usage_context: Entry point, forwarded to config creation and to
                the engine.
            stat_loggers: Optional stat loggers forwarded to the engine.

        Returns:
            A new engine. Its executor class is chosen by
            :meth:`_get_executor_cls`, and stats are logged unless
            ``engine_args.disable_log_stats`` is set.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config(usage_context)
        executor_class = cls._get_executor_cls(engine_config)
        # Create the LLM engine.
        engine = cls(
            vllm_config=engine_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

        return engine
513

514
515
516
517
518
    def __reduce__(self):
        """Refuse to be pickled.

        Failing loudly here guarantees the engine never ends up captured in
        the closure used to initialize Ray worker actors.
        """
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

519
520
521
522
523
524
    def __del__(self):
        """Shut the model executor down when the engine is collected."""
        # __init__ may have raised before model_executor was assigned, so
        # look it up with a default instead of plain attribute access.
        executor = getattr(self, "model_executor", None)
        if executor:
            executor.shutdown()

525
    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the engine's tokenizer group, checked against *group_type*.

        Raises:
            ValueError: if the engine has no tokenizer
                (``skip_tokenizer_init`` was set).
            TypeError: if the tokenizer group is not an instance of
                *group_type*.
        """
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")

        return tokenizer_group
540

541
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group provides for
        *lora_request*.

        Raises the same errors as :meth:`get_tokenizer_group` when no
        tokenizer is available.
        """
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
546

547
548
549
550
551
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Build the tokenizer group from the engine's model, scheduler,
        parallel and LoRA configs."""
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            lora_config=self.lora_config)
553

554
555
    def _verify_args(self) -> None:
        """Run the configs' ``verify_with_*`` cross-checks.

        The model and cache configs are always checked against the parallel
        config. The LoRA and prompt-adapter checks run only when those
        configs are set.
        """
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)
564

565
566
567
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Returns:
            The created :class:`SequenceGroup`, or ``None`` when ``params`` is
            a :class:`SamplingParams` with ``n > 1``. Such requests are handed
            to ``ParallelSampleSequenceGroup.add_request`` instead.

        Raises:
            ValueError: if ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = processed_inputs["decoder"]
            encoder_inputs = processed_inputs["encoder"]
        else:
            decoder_inputs = processed_inputs
            encoder_inputs = None

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        # The encoder sequence, if any, reuses the decoder sequence's seq_id.
        encoder_seq = (None if encoder_inputs is None else Sequence(
            seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
            prompt_adapter_request))

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with the fewest unfinished
        # sequence groups. min() returns the first minimal element, so ties
        # go to the lowest-index scheduler.
        min_cost_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group

649
650
    def stop_remote_worker_execution_loop(self) -> None:
        """Delegate stopping the remote worker execution loop to the model
        executor."""
        executor = self.model_executor
        executor.stop_remote_worker_execution_loop()
651

652
    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only signature: the prompt is passed as ``prompt``.
        ...

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only signature for the deprecated keyword-only ``inputs``.
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: The LoRA request to add.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter request to add.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            inputs: Deprecated alias for ``prompt``; when given it takes
                precedence over ``prompt``.

        Raises:
            ValueError: If the prompt or params are missing, if a LoRA request
                is given while LoRA is not enabled, if a non-zero priority is
                given without priority scheduling, or if guided decoding /
                logits processors are combined with multi-step decoding.
                When a tokenizer is configured, token prompts are also
                validated against its vocabulary.

        Details:
            - Validate the request arguments against the engine config.
            - Set arrival_time to the current time if it is None.
            - Validate token ids (if a tokenizer is available), then run the
              input preprocessor and input processor on the prompt.
            - Hand the processed inputs to ``_add_processed_request``, which
              builds the sequence group and adds it to a scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            ...    str(request_id),
            ...    example_prompt,
            ...    sampling_params)
            >>> # continue the request processing
            >>> ...
        """
        if inputs is not None:
            prompt = inputs
        # Explicit check instead of ``assert`` so it is not stripped by -O.
        if prompt is None or params is None:
            raise ValueError(
                "add_request() requires both a prompt and "
                "SamplingParams/PoolingParams.")

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if isinstance(params, SamplingParams) \
            and (params.guided_decoding or params.logits_processors) \
            and self.scheduler_config.num_scheduler_steps > 1:
            raise ValueError(
                "Guided decoding and logits processors are not supported "
                "in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        # Validate token ids before preprocessing: multimodal preprocessing
        # may add placeholder tokens that are outside the tokenizer's vocab.
        if self.tokenizer is not None:
            self._validate_token_prompt(
                prompt,
                tokenizer=self.get_tokenizer(lora_request=lora_request))

        preprocessed_inputs = self.input_preprocessor.preprocess(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    def _validate_token_prompt(self, prompt: PromptType,
                               tokenizer: AnyTokenizer):
        """Reject token-id prompts whose largest id exceeds the vocabulary.

        Some tokenizers silently decode out-of-vocab ids to empty text, and
        such ids are not otherwise caught before the model runs; at runtime
        they would crash a CUDA kernel with an index-out-of-bounds error and
        take the whole engine down. This check must run before multimodal
        preprocessing, which may insert dummy ``<image>`` tokens that are not
        part of the tokenizer's vocabulary.

        Non-token prompts and empty token prompts are left untouched (empty
        prompts are rejected later).

        Raises:
            ValueError: If any token id is greater than
                ``tokenizer.max_token_id``.
        """
        if not is_token_prompt(prompt):
            return

        token_ids = prompt["prompt_token_ids"]
        if len(token_ids) == 0:
            # The empty-prompt case is handled further down the pipeline.
            return

        largest_id = max(token_ids)
        if largest_id > tokenizer.max_token_id:
            raise ValueError(f"Token id {largest_id} is out of vocabulary")

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a SequenceGroup for a generation request.

        Validates the requested (prompt) logprob counts against the model
        config, attaches logits processors, takes a defensive copy of the
        params and applies the generation-config defaults before wrapping
        ``seq`` in a new SequenceGroup.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds the
                model's ``max_logprobs``.
        """
        limit = self.get_model_config().max_logprobs
        requested_counts = (sampling_params.logprobs,
                            sampling_params.prompt_logprobs)
        if any(count and count > limit for count in requested_counts):
            raise ValueError(f"Cannot request more than "
                             f"{limit} logprobs.")

        params = self._build_logits_processors(sampling_params, lora_request)

        # The sampler consumes these params, so work on a private copy.
        # clone() does not deep-copy LogitsProcessor objects.
        params = params.clone()
        params.update_from_generation_config(self.generation_config_fields,
                                             seq.eos_token_id)

        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             sampling_params=params,
                             lora_request=lora_request,
                             trace_headers=trace_headers,
                             prompt_adapter_request=prompt_adapter_request,
                             encoder_seq=encoder_seq,
                             priority=priority)

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a SequenceGroup for a pooling request.

        The pooling params are cloned first because the pooler reads them;
        the caller's object is never stored on the group.
        """
        params_copy = pooling_params.clone()
        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             lora_request=lora_request,
                             pooling_params=params_copy,
                             prompt_adapter_request=prompt_adapter_request,
                             encoder_seq=encoder_seq,
                             priority=priority)

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests by ID on every scheduler.

        Args:
            request_id: A single request ID or an iterable of request IDs.
                It is forwarded unchanged to each scheduler.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        schedulers = self.scheduler
        for sched in schedulers:
            sched.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Return the engine's model configuration (``self.model_config``)."""
        config = self.model_config
        return config

    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's parallel configuration
        (``self.parallel_config``)."""
        config = self.parallel_config
        return config

    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's decoding configuration
        (``self.decoding_config``)."""
        config = self.decoding_config
        return config

    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's scheduler configuration
        (``self.scheduler_config``)."""
        config = self.scheduler_config
        return config

    def get_lora_config(self) -> LoRAConfig:
        """Return the engine's LoRA configuration (``self.lora_config``)."""
        config = self.lora_config
        return config

    def get_num_unfinished_requests(self) -> int:
        """Return the total count of unfinished sequence groups summed over
        all schedulers (0 when there are none)."""
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total

    def has_unfinished_requests(self) -> bool:
        """Return True as soon as any scheduler reports unfinished sequences,
        False otherwise."""
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Return whether the scheduler at index ``virtual_engine`` still has
        unfinished sequences."""
        target = self.scheduler[virtual_engine]
        return target.has_unfinished_seqs()

    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache on every scheduler (all devices).

        Every scheduler is asked to reset, even if an earlier one fails;
        previously ``success and ...`` short-circuited and left the remaining
        schedulers' caches untouched after the first failure.

        Returns:
            True only if every scheduler reported a successful reset.
        """
        # Build the full list first so no reset call is skipped.
        results = [scheduler.reset_prefix_cache()
                   for scheduler in self.scheduler]
        return all(results)

    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Record pooled data on a pooling seq group and finish its sequences.

        Only ``outputs[0].data`` is stored; every sequence in the group is then
        marked ``FINISHED_STOPPED``.
        """
        first_output = outputs[0]
        seq_group.pooled_data = first_output.data

        stopped = SequenceStatus.FINISHED_STOPPED
        for sequence in seq_group.get_seqs():
            sequence.status = stopped

    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """Advance num_computed_tokens for a prompt seq group when Multi-Step
        is enabled.

        Decode seq groups are ignored here: their num_computed_tokens updates
        happen after the tokens are appended to the sequence.

        Args:
            seq_group: SequenceGroup whose num_computed_tokens is updated.
            seq_group_meta: Metadata of the given SequenceGroup.
            is_first_step_output: When available, whether the appended output
                token is the output of the first step of multi-step. None means
                the outputs of all multi-step steps were submitted in a single
                burst.
        """
        assert self.scheduler_config.is_multi_step

        # Guard: decode updates are handled elsewhere.
        if not seq_group_meta.is_prompt:
            return

        if self.scheduler_config.chunked_prefill_enabled:
            # Multi-step + chunked prefill: scheduled prompts are fully
            # processed in the first step, so only that step (or a single
            # burst of all steps) counts.
            if not (is_first_step_output is None or is_first_step_output):
                return
        else:
            # Plain multi-step: prompts are actually single-stepped, so the
            # update always applies.
            assert seq_group.state.num_steps == 1

        seq_group.update_num_computed_tokens(seq_group_meta.token_chunk_size)

    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and collect the resulting request outputs in ``ctx.request_outputs``.

        Args:
            ctx: The virtual engine context to work on. Each entry of
                ``ctx.output_queue`` is a tuple (outputs,
                seq_group_metadata_list, scheduler_outputs, is_async,
                is_last_step, is_first_step_output, skip).
            request_id: If provided, only this request is processed. The queue
                entry is then left in place (peeked, not popped) and the
                request's index is appended to ``skip`` so that it is not
                processed again when the rest of the entry is handled.

        Finished requests are freed from the schedulers, and request outputs
        are handed to ``self.process_request_outputs_callback`` when one is
        set. In the async case, stats logging and tracing are also done here.
        """
        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
            if self.scheduler_config.is_multi_step:
                outputs_by_sequence_group = create_output_by_sequence_group(
                    outputs, len(seq_group_metadata_list))
            elif self.speculative_config:
                # Decodes are multi-steps while prefills are not, outputting at
                # most 1 token. Separate them so that we can trigger chunk
                # processing without having to pad or copy over prompts K times
                # to match decodes structure (costly with prompt_logprobs).
                num_prefills = sum(sg.is_prompt
                                   for sg in seq_group_metadata_list)
                prefills, decodes = outputs[:num_prefills], outputs[
                    num_prefills:]
                outputs_by_sequence_group = create_output_by_sequence_group(
                    decodes,
                    num_seq_groups=len(seq_group_metadata_list) - num_prefills)
                outputs_by_sequence_group = [p.outputs for p in prefills
                                             ] + outputs_by_sequence_group
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
        elif len(outputs) == 1:
            outputs_by_sequence_group = outputs
        else:
            return None

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        finished_before: List[int] = []
        finished_now: List[int] = []
        empty_seq_indices: List[int] = []
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            # Tree-style speculative decoding may generate an empty output in
            # the first step. Guard on this group's own ``output`` (not the
            # whole batch) before indexing ``output[0]``.
            if (self.tree_decoding and output
                    and isinstance(output[0], CompletionSequenceGroupOutput)):
                samples = [o.samples[0] for o in output]
                valid_samples = [
                    sample for sample in samples
                    if sample.output_token != VLLM_INVALID_TOKEN_ID
                ]
                if not valid_samples:
                    # Nothing usable for this group yet; skip it entirely.
                    empty_seq_indices.append(i)
                    continue

            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
                        seq_group_meta.token_chunk_size or 0)

            if outputs:
                # Accumulate model timing from every SamplerOutput onto the
                # group's metrics.
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            if self.model_config.runner_type == "pooling":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Indices already handled above; a set gives O(1) membership checks
        # instead of scanning three lists per index.
        already_handled = set(finished_before).union(finished_now,
                                                     empty_seq_indices)

        # Create the outputs
        for i in indices:
            if i in skip or i in already_handled:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        for seq_group in scheduler_outputs.ignored_seq_groups:
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _fix_last_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Overwrite each sampled sequence's last token with the token read
        back from ``self.async_d2h``.

        ``self.async_d2h`` holds one row per sampled sequence, and
        ``output[0].sampler_out_ids`` gives the seq_id of each row. For every
        unfinished, sampling seq group (exactly one sequence per group), the
        first row whose seq_id matches the group's sequence supplies the new
        token. That token is written to ``samples[0].output_token`` and passed
        to ``seq.fix_last_token_id``. Groups with no matching row are left
        unchanged.
        """
        # NOTE(review): self.async_d2h is filled by a non_blocking
        # device-to-host copy; this assumes that copy has completed before
        # this runs -- confirm the caller synchronizes on the recorded event.
        sampled_rows = self.async_d2h.tolist()
        sampled_seq_ids = output[0].sampler_out_ids.tolist()

        # Index the rows by seq_id once instead of rescanning them per group.
        # setdefault keeps the first occurrence, matching a linear search.
        row_by_seq_id = {}
        for row, seq_id in zip(sampled_rows, sampled_seq_ids):
            row_by_seq_id.setdefault(seq_id, row)

        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                continue
            if not seq_group_metadata.do_sample:
                continue

            sample = sequence_group_outputs.samples[0]

            assert len(seq_group.seqs) == 1
            seq = seq_group.seqs[0]

            if seq.seq_id in row_by_seq_id:
                sample.output_token = row_by_seq_id[seq.seq_id][0]
                seq.fix_last_token_id(sample.output_token)

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Append the tokens from one model run to their sequences.

        The output processor normally does this, but it has to happen here
        when the worker runs the next step's forward pass asynchronously.
        Finished groups are skipped. Every other group gets its
        num_computed_tokens advanced, and if it samples, its single sampled
        token is appended.
        """
        multi_step = self.scheduler_config.is_multi_step
        triples = zip(seq_group_metadata_list, output, scheduled_seq_groups)
        for meta, group_output, scheduled in triples:
            seq_group = scheduled.seq_group
            if seq_group.is_finished():
                continue

            if multi_step:
                # Only prefill sequences are updated by this helper.
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, meta, seq_group.state.num_steps == 1)
            else:
                chunk = meta.token_chunk_size
                seq_group.update_num_computed_tokens(
                    0 if chunk is None else chunk)

            if not meta.do_sample:
                continue

            assert len(group_output.samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1)")
            sample = group_output.samples[0]

            assert len(seq_group.seqs) == 1
            seq = seq_group.seqs[0]

            # Must be evaluated before appending. Outside multi-step there is
            # never an extra increment. In multi-step, the increment happens
            # only while the sequence still has uncomputed tokens.
            skip_increment = (not multi_step
                              or seq.data.get_num_uncomputed_tokens() == 0)
            seq.append_token_id(sample.output_token, sample.logprobs)
            if not skip_increment:
                seq_group.update_num_computed_tokens(1)

    def trans_last_output_tensor(self,
                                 last_output) -> Optional[torch.Tensor]:
        """Placeholder hook that currently ignores ``last_output`` and always
        returns ``None``; the return annotation reflects that."""
        # NOTE(review): stub with no implementation here -- confirm whether
        # callers or subclasses are expected to provide a real tensor.
        return None

    def thread_zero_overhead(self):
        """Endless schedule-and-execute loop used by zero-overhead stepping.

        Each iteration performs these steps in order:

        1. Wait on ``self.sem_m2s``.
        2. Schedule a batch on virtual engine 0 and store the schedule on its
           context.
        3. If a previous step's record exists, start a non-blocking copy of
           its sampled tokens to CPU (``self.async_d2h``), record
           ``self.async_event``, and put that record on ``self.q_recorder``.
        4. Execute the model for the new batch and advance the sequences.
        5. Remember the result as the next ``self.last_record``.

        Never returns.
        """
        while True:
            self.sem_m2s.acquire()
            ve = 0
            ctx = self.scheduler_contexts[ve]
            sched = self.scheduler[ve]

            # Start every scheduler iteration with an empty outputs list.
            ctx.request_outputs.clear()
            metadata_list, sched_outputs, _allow_async = sched.schedule()

            ctx.seq_group_metadata_list = metadata_list
            ctx.scheduler_outputs = sched_outputs

            finished_ids = sched.get_and_reset_finished_requests_ids()

            assert metadata_list is not None
            assert sched_outputs is not None

            prev_ids = None
            prev_tensor = None
            prev_record = self.last_record
            if prev_record is not None:
                prev_out = prev_record[0][0]
                prev_ids = prev_out.sampler_out_ids
                prev_tensor = prev_out.sampler_out_tenosr
                # Kick off the device-to-host copy without waiting; the event
                # recorded right after is presumably what the consumer
                # synchronizes on -- verify in zero_overhead_step.
                self.async_d2h = prev_tensor.to('cpu', non_blocking=True)
                self.async_event.record()
                self.q_recorder.put(prev_record)

            last_sampled_token_ids = self._get_last_sampled_token_ids(ve)

            # last_sampled_token_ids is how the non-last PP stages receive
            # the last sampled ids for in-place prepare_input.
            request = ExecuteModelRequest(
                seq_group_metadata_list=metadata_list,
                blocks_to_swap_in=sched_outputs.blocks_to_swap_in,
                blocks_to_swap_out=sched_outputs.blocks_to_swap_out,
                blocks_to_copy=sched_outputs.blocks_to_copy,
                num_lookahead_slots=sched_outputs.num_lookahead_slots,
                running_queue_size=sched_outputs.running_queue_size,
                finished_requests_ids=finished_ids,
                last_sampled_token_ids=last_sampled_token_ids,
                last_outputs_ids=prev_ids,
                last_outputs_sample=prev_tensor)
            # No async_callback is attached to the request in this mode.

            outputs = self.model_executor.execute_model(
                execute_model_req=request)

            self._advance_to_next_step(outputs[0], metadata_list,
                                       sched_outputs.scheduled_seq_groups)
            self.last_record = [outputs, metadata_list, sched_outputs]

    def zero_overhead_step(
            self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Post-process the step previously launched by
        ``thread_zero_overhead``.

        The method first releases ``self.sem_m2s`` so the background thread
        can schedule and launch the next model step. It then blocks on
        ``self.q_recorder`` for the record of the previously launched step.
        A ``None`` record marks the first step; an empty list is returned in
        that case.

        Otherwise, it waits on ``self.async_event`` and finalizes the previous
        step's outputs with ``_fix_last_step``. The outputs are then appended
        to the scheduler context's output queue, processed, logged and
        traced. When no unfinished requests remain, the remote worker
        execution loop is stopped.

        Returns:
            ``ctx.request_outputs`` for virtual engine 0, or an empty list on
            the first step.
        """
        # Let the background thread start the next iteration while this
        # thread post-processes the previous one.
        self.sem_m2s.release()
        record = self.q_recorder.get()
        if record is None:  # None is for the first step
            # Nothing has been executed yet. Return an empty list rather than
            # None so the result matches the annotated return type and can be
            # iterated like any other step() result.
            return []
        virtual_engine = 0
        ctx = self.scheduler_contexts[virtual_engine]
        outputs, seq_group_metadata_list, scheduler_outputs = record
        ctx.seq_group_metadata_list = seq_group_metadata_list
        ctx.scheduler_outputs = scheduler_outputs
        # Wait on the event the background thread recorded right after
        # starting the non-blocking device-to-host copy.
        self.async_event.synchronize()
        self._fix_last_step(
            outputs, seq_group_metadata_list,
            scheduler_outputs.scheduled_seq_groups)

        # is_first_step_output is True only when the num_steps of all
        # the sequences are 1. When the num_steps > 1,
        # multi_step_model_runner does the first-step output append.
        is_first_step_output: bool = False if not seq_group_metadata_list \
            else seq_group_metadata_list[0].state.num_steps == 1

        # Add results to the output_queue
        ctx.append_output(outputs=outputs,
                          seq_group_metadata_list=seq_group_metadata_list,
                          scheduler_outputs=scheduler_outputs,
                          is_async=True,
                          is_last_step=True,
                          is_first_step_output=is_first_step_output)

        # Outputs are always processed here, regardless of async support.
        self._process_model_outputs(ctx=ctx)

        # Log stats.
        self.do_log_stats(scheduler_outputs, outputs)

        # Tracing
        self.do_tracing(scheduler_outputs)

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()
        return ctx.request_outputs
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        When ``self.zero_overhead`` is set, the call is delegated to
        ``zero_overhead_step`` instead.

        Raises:
            NotImplementedError: If pipeline parallelism is configured
                (``pipeline_parallel_size > 1``) on the non-zero-overhead
                path.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        if self.zero_overhead:
            return self.zero_overhead_step()

        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # For llm_engine, there is no pipeline parallel support, so the engine
        # used is always 0.
        virtual_engine = 0

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            finished_requests_ids = list()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)
            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
            # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps.
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            # Add results to the output_queue
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")

                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Check if need to run the usual non-async path
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for nowto make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        if any([
                seq_group.state.remaining_steps != ref_remaining_steps
                for seq_group in seq_group_metadata_list[1:]
        ]):
1656
1657
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")
1658
1659
1660
1661
1662
1663

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
1664
1665
1666
1667
1668
1669
1670
1671
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None
    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                and output[0] is not None):
            last_output = output[-1]
            assert last_output is not None
            assert last_output.sampled_token_ids_cpu is not None
            assert last_output.sampled_token_ids is None
            assert last_output.sampled_token_probs is None
            self.cached_scheduler_outputs[
                virtual_engine].last_output = last_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` in ``self.stat_loggers`` under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If a logger with this name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        self.stat_loggers[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            RuntimeError: If stat logging is disabled (``self.log_stats`` is
                falsy).
            KeyError: If no logger with this name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Collect stats and hand them to every registered stat logger.

        This does nothing when ``self.log_stats`` is falsy. Because every
        argument is optional, it can also be called with no arguments to
        force a log when no requests are active.

        Args:
            scheduler_outputs: Forwarded to ``_get_stats``.
            model_output: Forwarded to ``_get_stats``.
            finished_before: Forwarded to ``_get_stats``.
            skip: Forwarded to ``_get_stats``.
        """
        if not self.log_stats:
            return
        stats = self._get_stats(scheduler_outputs, model_output,
                                finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(stats)
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
1730
1731
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
1732
1733
1734
1735
1736
1737
1738
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
1739
1740
1741
1742
            finished_before: Optional, indices of sequences that were finished
                before. These sequences will be ignored.
            skip: Optional, indices of sequences that were preempted. These
                sequences will be ignored.
1743
        """
1744
        now = time.time()
Woosuk Kwon's avatar
Woosuk Kwon committed
1745

1746
1747
        # System State
        #   Scheduler State
1748
1749
1750
1751
1752
1753
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)
1754
1755

        # KV Cache Usage in %
1756
        num_total_gpu = self.cache_config.num_gpu_blocks
1757
        gpu_cache_usage_sys = 0.
1758
        if num_total_gpu:  # Guard against both None and 0
1759
1760
1761
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
1762
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
Woosuk Kwon's avatar
Woosuk Kwon committed
1763

1764
        num_total_cpu = self.cache_config.num_cpu_blocks
1765
        cpu_cache_usage_sys = 0.
1766
        if num_total_cpu:  # Guard against both None and 0
1767
1768
1769
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
1770
1771
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

1772
1773
1774
1775
1776
1777
1778
        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

1779
1780
1781
        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
harrywu's avatar
harrywu committed
1782
        num_tokens_iter = 0
1783
1784
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
1785
1786
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)
1787
1788
1789
1790

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
harrywu's avatar
harrywu committed
1791
1792
1793
1794
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
1795
1796
1797
        time_in_queue_requests: List[float] = []
        model_forward_time_requests: List[float] = []
        model_execute_time_requests: List[float] = []
1798
1799
1800
1801
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
harrywu's avatar
harrywu committed
1802
        max_num_generation_tokens_requests: List[int] = []
1803
        max_tokens_requests: List[int] = []
1804
1805
        finished_reason_requests: List[str] = []

1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
        # Lora requests
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

1825
1826
        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
1827
        if scheduler_outputs is not None:
1828
1829
1830
1831
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

1832
            num_generation_tokens_from_prefill_groups = 0
1833
1834
1835
1836
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.
1837
1838
1839

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
1840
1841
1842
1843
1844
                # Skip double logging when using async output proc
                if finished_before and idx in finished_before:
                    actual_num_batched_tokens -= 1
                    continue

1845
1846
1847
1848
                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if skip and idx in skip:
                    continue
1849

1850
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
1851
                seq_group = scheduled_seq_group.seq_group
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
1864
                        latency = seq_group.get_last_token_latency()
1865
1866
1867
1868
1869
1870
1871
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
1872
                    latency = seq_group.get_last_token_latency()
1873
                    time_per_output_tokens_iter.append(latency)
1874
1875
1876
1877
1878
1879
1880
1881
1882
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1
1883
1884
1885
1886
1887
1888

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
1889
                if seq_group.is_finished():
1890
                    # Latency timings
1891
1892
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
harrywu's avatar
harrywu committed
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
1905
1906
1907
1908
1909
1910
1911
1912
1913
                    if seq_group.metrics.time_in_queue is not None:
                        time_in_queue_requests.append(
                            seq_group.metrics.time_in_queue)
                    if seq_group.metrics.model_forward_time is not None:
                        model_forward_time_requests.append(
                            seq_group.metrics.model_forward_time)
                    if seq_group.metrics.model_execute_time is not None:
                        model_execute_time_requests.append(
                            seq_group.metrics.model_execute_time * 1000)
1914
1915
1916
1917
1918
1919
1920
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
harrywu's avatar
harrywu committed
1921
1922
1923
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
1924
1925
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
1926
1927
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
1940
                actual_num_batched_tokens - num_prompt_tokens_iter +
1941
                num_generation_tokens_from_prefill_groups)
harrywu's avatar
harrywu committed
1942
1943
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)
1944
1945
        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
1946
1947
        if model_output and isinstance(model_output[0], SamplerOutput) and (
                model_output[0].spec_decode_worker_metrics is not None):
1948
1949
1950
1951
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

1952
1953
        return Stats(
            now=now,
1954
1955
1956
1957
1958
1959
1960
1961
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
1962
1963
1964
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
1965
1966
1967
1968

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
harrywu's avatar
harrywu committed
1969
            num_tokens_iter=num_tokens_iter,
1970
1971
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
1972
            spec_decode_metrics=spec_decode_metrics,
1973
            num_preemption_iter=num_preemption_iter,
1974
1975
1976
1977

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
harrywu's avatar
harrywu committed
1978
1979
1980
1981
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
1982
1983
1984
            time_in_queue_requests=time_in_queue_requests,
            model_forward_time_requests=model_forward_time_requests,
            model_execute_time_requests=model_execute_time_requests,
1985
1986
1987
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
harrywu's avatar
harrywu committed
1988
1989
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
1990
            n_requests=n_requests,
1991
            max_tokens_requests=max_tokens_requests,
1992
            finished_reason_requests=finished_reason_requests,
1993
1994
1995
            max_lora=str(max_lora_stat),
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
1996

1997
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA adapter registration to the model executor.

        Args:
            lora_request: The adapter to register.

        Returns:
            The value returned by the executor's ``add_lora``.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1999
2000

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model executor to drop the LoRA adapter ``lora_id``.

        Returns:
            The value returned by the executor's ``remove_lora``.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
2002

2003
    def list_loras(self) -> Set[int]:
        """Return the LoRA adapter IDs as reported by the model executor."""
        executor = self.model_executor
        return executor.list_loras()
2005

2006
2007
2008
    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model executor to pin the LoRA adapter ``lora_id``.

        Returns:
            The value returned by the executor's ``pin_lora``.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward a prompt adapter registration to the model executor.

        Returns:
            The value returned by the executor's ``add_prompt_adapter``.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Ask the model executor to drop prompt adapter ``prompt_adapter_id``.

        Returns:
            The value returned by the executor's ``remove_prompt_adapter``.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the prompt adapter IDs as reported by the model executor."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

2019
2020
2021
2022
2023
2024
    def start_profile(self) -> None:
        """Delegate profiler start-up to the model executor."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Delegate profiler shut-down to the model executor."""
        executor = self.model_executor
        executor.stop_profile()

2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
    def sleep(self, level: int = 1) -> None:
        """Forward a sleep request at ``level`` to the model executor.

        Raises:
            AssertionError: if ``enable_sleep_mode`` is off in the model
                config (an ``assert``, so skipped under ``python -O``).
        """
        model_cfg = self.vllm_config.model_config
        assert model_cfg.enable_sleep_mode, (
            "Sleep mode is not enabled in the model config")
        executor = self.model_executor
        executor.sleep(level=level)

    def wake_up(self) -> None:
        """Forward a wake-up request to the model executor.

        Raises:
            AssertionError: if ``enable_sleep_mode`` is off in the model
                config (an ``assert``, so skipped under ``python -O``).
        """
        model_cfg = self.vllm_config.model_config
        assert model_cfg.enable_sleep_mode, (
            "Sleep mode is not enabled in the model config")
        executor = self.model_executor
        executor.wake_up()

2035
    def check_health(self) -> None:
        """Run the health checks of the tokenizer (when one is set and
        truthy) and then of the model executor.

        Any exception raised by either check propagates to the caller.
        """
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        executor = self.model_executor
        executor.check_health()
2039
2040
2041
2042

    def is_tracing_enabled(self) -> bool:
        """Return True when a tracer is configured (``self.tracer`` set)."""
        tracer = self.tracer
        return tracer is not None

2043
2044
2045
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for every scheduled sequence group that is
        finished.

        Args:
            scheduler_outputs: Scheduler outputs of the step; only its
                ``scheduled_seq_groups`` are inspected.
            finished_before: Indices into ``scheduled_seq_groups`` to skip,
                used to avoid tracing the same group twice when async
                output processing is in use.
        """
        if self.tracer is None:
            return

        # Hoist the skip list into a set once so each membership test is
        # O(1) instead of a linear scan per scheduled group.
        skip_indices = set(finished_before) if finished_before else set()

        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if idx in skip_indices:
                continue

            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit one ``llm_request`` SERVER span describing ``seq_group``.

        The span is parented on the context extracted from the group's trace
        headers and starts at the group's arrival time (in nanoseconds). It
        records the model name, request id, selected sampling parameters,
        token counts and latency figures as span attributes. Nothing is
        emitted when there is no tracer or the group has no sampling params.
        """
        if self.tracer is None:
            return
        params = seq_group.sampling_params
        if params is None:
            return

        stats = seq_group.metrics
        start_ns = int(stats.arrival_time * 1e9)
        parent_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=parent_context,
                start_time=start_ns) as span:
            first_token_latency = stats.first_token_time - stats.arrival_time
            total_latency = stats.finished_time - stats.arrival_time
            completion_tokens = sum(
                seq.get_output_len() for seq in seq_group.get_finished_seqs())

            # Unconditional attributes, recorded in this order.
            attributes = (
                (SpanAttributes.GEN_AI_RESPONSE_MODEL,
                 self.model_config.model),
                (SpanAttributes.GEN_AI_REQUEST_ID,
                 seq_group.request_id),
                (SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                 params.temperature),
                (SpanAttributes.GEN_AI_REQUEST_TOP_P,
                 params.top_p),
                (SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                 params.max_tokens),
                (SpanAttributes.GEN_AI_REQUEST_N,
                 params.n),
                (SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                 seq_group.num_seqs()),
                (SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                 len(seq_group.prompt_token_ids)),
                (SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                 completion_tokens),
                (SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                 stats.time_in_queue),
                (SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                 first_token_latency),
                (SpanAttributes.GEN_AI_LATENCY_E2E,
                 total_latency),
            )
            for attr_key, attr_value in attributes:
                span.set_attribute(attr_key, attr_value)

            # Optional timings are only recorded when they were measured.
            if stats.scheduler_time is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    stats.scheduler_time)
            if stats.model_forward_time is not None:
                # Divided by 1000 before recording; presumably ms -> s
                # (NOTE(review): confirm unit where it is measured).
                span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    stats.model_forward_time / 1000.0)
            if stats.model_execute_time is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    stats.model_execute_time)
2113

2114
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Reject prompts that are empty or, for multimodal models, longer
        than ``max_model_len``.

        For encoder-decoder inputs the checked prompt is the decoder prompt
        when the model is multimodal, and the encoder prompt otherwise.
        ``lora_request`` is accepted but not used by these checks.

        Raises:
            ValueError: if the prompt token ids are missing/empty, or a
                multimodal prompt exceeds ``max_model_len``.
        """
        model_config = self.model_config
        if not is_encoder_decoder_inputs(inputs):
            prompt_inputs = inputs
        else:
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            side = ("decoder"
                    if model_config.is_multimodal_model else "encoder")
            prompt_inputs = inputs[side]

        token_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
        if token_ids is None or len(token_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        if not model_config.is_multimodal_model:
            return

        max_prompt_len = model_config.max_model_len
        prompt_len = len(token_ids)
        if prompt_len > max_prompt_len:
            raise ValueError(
                f"The prompt (total length {prompt_len}) is too long "
                f"to fit into the model (context length {max_prompt_len}). "
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens plus multimodal tokens. For image "
                "inputs, the number of image tokens depends on the number "
                "of images, and possibly their aspect ratios as well.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # (e.g. against self.scheduler_config.max_num_batched_tokens).
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the guided_decoding,
        logit_bias, allowed_token_ids and bad_words fields in
        sampling_params.

        The guided_decoding, logit_bias and allowed_token_ids fields are
        consumed (reset to None), and the constructed processors are added
        to the logits_processors field. Returns the modified sampling params.

        The caller's ``sampling_params`` is never mutated. A single
        SamplingParams object may be shared by several requests, so whenever
        there is anything to build, the method does three things:
          - it works on a shallow copy of the params;
          - it copies ``guided_decoding`` before filling in its backend;
          - it stores the result in a *new* ``logits_processors`` list.
        Extending the existing list in place would leak this request's
        processors into every later request that shares the same object.
        The guided-decoding processors in particular may carry per-request
        state.

        Args:
            sampling_params: Sampling parameters of the request.
            lora_request: Optional LoRA request; used to pick the tokenizer.

        Returns:
            ``sampling_params`` itself when nothing needs to be built,
            otherwise a modified shallow copy of it.
        """
        has_guided_decoding = sampling_params.guided_decoding is not None
        has_openai_params = bool(sampling_params.logit_bias
                                 or sampling_params.allowed_token_ids)
        # ``bool`` also tolerates ``bad_words=None``.
        has_bad_words = bool(sampling_params.bad_words)
        if not (has_guided_decoding or has_openai_params or has_bad_words):
            return sampling_params

        # Defensively copy sampling params: fields below are rewritten and
        # guided decoding logits processors can have different state for
        # each request.
        sampling_params = copy.copy(sampling_params)
        tokenizer = self.get_tokenizer(lora_request=lora_request)
        logits_processors = []

        if has_guided_decoding:
            # Copy so that filling in the default backend does not write
            # onto the caller's (shallow-shared) guided decoding params.
            guided_decoding = copy.copy(sampling_params.guided_decoding)

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config)
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if has_openai_params:
            logits_processors.extend(
                get_openai_logits_processors(
                    logit_bias=sampling_params.logit_bias,
                    allowed_token_ids=sampling_params.allowed_token_ids,
                    tokenizer=tokenizer))

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if has_bad_words:
            logits_processors.extend(
                get_bad_words_logits_processors(
                    bad_words=sampling_params.bad_words,
                    tokenizer=tokenizer))

        if logits_processors:
            # Always build a fresh list: the shallow copy still references
            # the caller's list object, so ``.extend`` would mutate it.
            existing = sampling_params.logits_processors or []
            sampling_params.logits_processors = [
                *existing, *logits_processors
            ]

        return sampling_params