llm_engine.py 85 KB
Newer Older
1
import copy
Antoni Baum's avatar
Antoni Baum committed
2
import time
3
from collections import Counter as collectionsCounter
4
from collections import deque
5
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from functools import partial
8
9
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
                    List, Mapping, NamedTuple, Optional)
10
from typing import Sequence as GenericSequence
11
from typing import Set, Type, Union, cast, overload
12

13
import torch
14
from typing_extensions import TypeVar, deprecated
15

16
import vllm.envs as envs
17
18
19
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
20
21
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.engine.arg_utils import EngineArgs
23
from vllm.engine.metrics_types import StatLoggerBase, Stats
24
25
26
27
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
28
29
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
30
from vllm.executor.executor_base import ExecutorBase
31
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
32
                         PromptType, SingletonInputsAdapter)
33
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
34
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.logger import init_logger
36
from vllm.logits_process import get_bad_words_logits_processors
37
from vllm.lora.request import LoRARequest
38
39
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
40
from vllm.model_executor.layers.sampler import SamplerOutput
41
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
42
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
43
44
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
45
from vllm.prompt_adapter.request import PromptAdapterRequest
46
from vllm.sampling_params import RequestOutputKind, SamplingParams
47
48
49
50
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceStatus)
51
52
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
53
from vllm.transformers_utils.detokenizer import Detokenizer
54
from vllm.transformers_utils.tokenizer import AnyTokenizer
55
from vllm.transformers_utils.tokenizer_group import (
56
    BaseTokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
57
58
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
59
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
60
from vllm.version import __version__ as VLLM_VERSION
61
62

# Module-level logger for the engine.
logger = init_logger(__name__)

# Value passed as ``local_interval`` to the default stat loggers.
_LOCAL_LOGGING_INTERVAL_SEC = 5

# Tokenizer-group type accepted by ``LLMEngine.get_tokenizer_group``.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
# Request-output type used by the ``validate_output*`` helpers.
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
67
68


69
70
71
72
73
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step.

    Every field has a default, so a freshly constructed instance represents
    "nothing cached yet".
    """
    # Sequence-group metadata cached alongside ``scheduler_outputs``.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # The cached scheduler outputs, or None when nothing is cached.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Cached flag; presumably whether async output processing was allowed
    # for the cached schedule -- TODO confirm against the code that sets it.
    allow_async_output_proc: bool = False
    # Cached sampler output, or None when there is none.
    last_output: Optional[SamplerOutput] = None
76
77


78
79
80
81
82
83
class OutputData(NamedTuple):
    """One entry of ``SchedulerContext.output_queue``.

    Bundles the sampler outputs with the scheduling information and flags
    that were supplied to ``SchedulerContext.append_output``.
    """
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # Indicates if this output is from the first step of the
    # multi-step. When multi-step is disabled, this is always
    # set to True.
    # is_first_step_output is invalid when `outputs` has
    # outputs from multiple steps.
    is_first_step_output: Optional[bool]
    # List of ints; ``append_output`` always starts it empty.
    skip: List[int]


93
class SchedulerContext:
    """Mutable state holding queued model outputs and scheduler results.

    Holds a FIFO of :class:`OutputData` entries, a list of request outputs,
    and slots for the scheduler results (sequence-group metadata and
    scheduler outputs).
    """

    def __init__(self, multi_step_stream_outputs: bool = False):
        """Create an empty context.

        Args:
            multi_step_stream_outputs: Stored unchanged on the instance.
        """
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Wrap the arguments in an :class:`OutputData` and enqueue it.

        The new entry is appended to the right end of ``output_queue`` and
        always starts with an empty ``skip`` list.
        """
        self.output_queue.append(
            OutputData(outputs=outputs,
                       seq_group_metadata_list=seq_group_metadata_list,
                       scheduler_outputs=scheduler_outputs,
                       is_async=is_async,
                       is_last_step=is_last_step,
                       is_first_step_output=is_first_step_output,
                       skip=[]))
118
119


120
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
121
    """An LLM engine that receives requests and generates texts.
122

Woosuk Kwon's avatar
Woosuk Kwon committed
123
    This is the main class for the vLLM engine. It receives requests
124
125
126
127
128
129
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

130
131
    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
132

133
    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
134
    :ref:`engine-args`)
135
136
137
138
139
140
141

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
142
        device_config: The configuration related to the device.
143
144
145
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
146
147
        executor_class: The model executor class for managing distributed
            execution.
148
        prompt_adapter_config (Optional): The configuration related to serving
149
            prompt adapters.
150
        log_stats: Whether to log statistics.
151
        usage_context: Specified entry point, used for usage info collection.
152
    """
153

154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return ``output`` typed as ``output_type``, optionally checking it.

        The ``isinstance`` check only runs when ``DO_VALIDATE_OUTPUT`` is
        set (or under ``TYPE_CHECKING``); otherwise the value is passed
        through unchecked.

        Raises:
            TypeError: If validation is enabled and ``output`` is not an
                instance of ``output_type``.
        """
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

        return cast(_O, output)
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Sequence counterpart of :meth:`validate_output`.

        With validation disabled the input sequence object itself is
        returned unchanged (no copy). With validation enabled a new list is
        built after checking every element.

        Raises:
            TypeError: If validation is enabled and any element is not an
                instance of ``output_type``.
        """
        if not (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT):
            # Fast path: hand back the caller's sequence as-is.
            return cast(List[_O], outputs)

        checked: List[_O] = []
        for item in outputs:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(item)}")
            checked.append(item)
        return checked

    tokenizer: Optional[BaseTokenizerGroup]

205
206
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the engine from a single aggregated config object.

        Args:
            vllm_config: Aggregated configuration; its sub-configs are
                copied onto ``self`` for convenient access.
            executor_class: Class instantiated with ``vllm_config`` to
                create ``self.model_executor``.
            log_stats: Whether to set up stat loggers.
            usage_context: Passed to the usage-stats report.
            stat_loggers: Stat loggers to use instead of the default
                logging/prometheus pair. Only consulted when ``log_stats``
                is True.
            input_registry: Used to create ``self.input_processor``.
            mm_registry: Passed to the :class:`InputPreprocessor`.
            use_cached_outputs: Stored on the engine and logged.
        """

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        # Fall back to default-constructed configs when these are unset.
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        logger.info(
            "Initializing an LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(vllm_config=vllm_config, )

        # KV caches are not initialized for the "pooling" runner type.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cache / context / scheduler per pipeline-parallel stage.
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            # weak_bind: presumably avoids a strong reference from the
            # callback back to the engine -- TODO confirm in vllm.utils.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        # NOTE: ``self.stat_loggers`` is only assigned when log_stats is True.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is only set up when an OTLP endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

406
407
408
409
410
411
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache. If ``cache_config.num_gpu_blocks_override``
        is set, it replaces the GPU block count reported by the executor.
        The final counts are written back to ``cache_config`` and passed to
        the executor's ``initialize_cache``.
        """
        start = time.time()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.time() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
431

432
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Select the executor class for the given configuration.

        Resolution order:

        1. If ``distributed_executor_backend`` is itself a class, it is
           used directly (it must subclass ``ExecutorBase``).
        2. If ``world_size > 1``, the backend name selects the executor:
           ``"ray"``, ``"mp"``, ``"uni"`` or ``"external_launcher"``.
        3. Otherwise the single-process ``UniProcExecutor`` is used.

        Raises:
            TypeError: If a class is given that is not an ``ExecutorBase``
                subclass.
            ValueError: If ``world_size > 1`` and the backend name is not
                one of the recognized values.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        # Initialize the cluster and specify the executor class.
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif engine_config.parallel_config.world_size > 1:
            # Executor modules are imported lazily, only for the chosen
            # backend.
            if distributed_executor_backend == "ray":
                from vllm.executor.ray_distributed_executor import (
                    RayDistributedExecutor)
                executor_class = RayDistributedExecutor
            elif distributed_executor_backend == "mp":
                from vllm.executor.mp_distributed_executor import (
                    MultiprocessingDistributedExecutor)
                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                    "multiprocessing distributed executor backend does not "
                    "support VLLM_USE_RAY_SPMD_WORKER=1")
                executor_class = MultiprocessingDistributedExecutor
            elif distributed_executor_backend == "uni":
                # JAX-style, single-process, multi-device executor.
                from vllm.executor.uniproc_executor import UniProcExecutor
                executor_class = UniProcExecutor
            elif distributed_executor_backend == "external_launcher":
                # executor with external launcher
                from vllm.executor.uniproc_executor import (  # noqa
                    ExecutorWithExternalLauncher)
                executor_class = ExecutorWithExternalLauncher
            else:
                # Without this branch an unknown backend would surface as
                # an UnboundLocalError at the return statement below.
                raise ValueError(
                    "unrecognized distributed_executor_backend: "
                    f"{distributed_executor_backend}")
        else:
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Args:
            engine_args: Arguments used to build the engine config.
            usage_context: Forwarded to the config builder and the engine.
            stat_loggers: Optional stat loggers forwarded to the engine.

        Returns:
            A new engine instance. Stats logging is enabled unless
            ``engine_args.disable_log_stats`` is set.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config(usage_context)
        executor_class = cls._get_executor_cls(engine_config)
        # Create the LLM engine.
        engine = cls(
            vllm_config=engine_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

        return engine
491

492
493
494
495
496
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

497
498
499
500
501
502
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

503
    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the engine's tokenizer group, checked against a type.

        Args:
            group_type: Type the tokenizer group is expected to be an
                instance of.

        Raises:
            ValueError: If the engine has no tokenizer
                (``self.tokenizer`` is None).
            TypeError: If the tokenizer group is not an instance of
                ``group_type``.
        """
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")

        return tokenizer_group
518

519
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group provides for a request.

        Args:
            lora_request: Optional LoRA request, forwarded to the tokenizer
                group's ``get_lora_tokenizer``.

        Raises:
            ValueError: If the engine has no tokenizer (raised by
                :meth:`get_tokenizer_group`).
        """
        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
524

525
526
527
528
529
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        """Create the tokenizer group from the engine's configs."""
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            lora_config=self.lora_config)
531

532
533
    def _verify_args(self) -> None:
        """Cross-check the engine's configs against one another.

        The model and cache configs are verified against the parallel
        config. The LoRA and prompt-adapter configs are verified only when
        they are set.
        """
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)
542

543
544
545
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Returns:
            The created sequence group, or None when the request uses
            ``SamplingParams`` with ``n > 1``; that case is delegated to
            ``ParallelSampleSequenceGroup.add_request``.

        Raises:
            ValueError: If ``params`` is neither ``SamplingParams`` nor
                ``PoolingParams``.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = processed_inputs["decoder"]
            encoder_inputs = processed_inputs["encoder"]
        else:
            decoder_inputs = processed_inputs
            encoder_inputs = None

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)

        # The encoder sequence, when present, shares the decoder's seq_id.
        encoder_seq = (None if encoder_inputs is None else Sequence(
            seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
            prompt_adapter_request))

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        # ``min`` returns the first minimal element, so ties go to the
        # lowest-indexed scheduler.
        min_cost_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group

627
628
    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()
629

630
    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only overload: the preferred call form, using ``prompt``.
        ...

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Typing-only overload: the deprecated keyword-only ``inputs`` form.
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def add_request(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used.
            lora_request: The LoRA request to add.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter request to add.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            inputs: Deprecated alias of ``prompt``. When given, it takes
                precedence over ``prompt``.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled,
                if a non-zero ``priority`` is given while the scheduler
                policy is not ``"priority"``, or if guided decoding or logits
                processors are requested while ``num_scheduler_steps > 1``.
            AssertionError: If no prompt or no ``params`` is supplied.

        Details:
            - Validate the request arguments against the engine configuration.
            - Set arrival_time to the current time if it is None.
            - If the engine has a tokenizer, reject token prompts whose
              token ids exceed the tokenizer's maximum id.
            - Run the prompt through ``self.input_preprocessor.preprocess``
              and then ``self.input_processor``.
            - Forward the processed inputs to ``_add_processed_request``.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            ...    str(request_id),
            ...    example_prompt,
            ...    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        # Backward compatibility: the deprecated 'inputs' keyword overrides
        # 'prompt' when it is supplied.
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        if isinstance(params, SamplingParams) \
            and (params.guided_decoding or params.logits_processors) \
            and self.scheduler_config.num_scheduler_steps > 1:
            raise ValueError(
                "Guided decoding and logits processors are not supported "
                "in multi-step decoding")

        if arrival_time is None:
            arrival_time = time.time()

        # Token-id validation needs a tokenizer; skip it when the engine was
        # set up without one.
        if self.tokenizer is not None:
            self._validate_token_prompt(
                prompt,
                tokenizer=self.get_tokenizer(lora_request=lora_request))

        preprocessed_inputs = self.input_preprocessor.preprocess(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )
767

768
769
770
771
772
773
774
775
776
777
778
    def _validate_token_prompt(self, prompt: PromptType,
                               tokenizer: AnyTokenizer):
        """Reject token prompts whose largest id exceeds the tokenizer's range.

        Only prompts recognised by ``is_token_prompt`` are inspected; any
        other prompt, and a token prompt with no ids, is accepted as-is
        (the empty-prompt check is handled later).

        Args:
            prompt: The prompt to inspect.
            tokenizer: Tokenizer whose ``max_token_id`` bounds the valid ids.

        Raises:
            ValueError: If the largest token id is greater than
                ``tokenizer.max_token_id``.
        """
        # Why this guard exists: some tokenizers decode out-of-vocab ids to
        # empty text without complaint, and nothing else checks the ids
        # against the maximum before the model runs. Such ids later crash a
        # cuda kernel with an index out of bounds error, which takes down the
        # entire engine. The check has to run before multimodal input
        # pre-processing, because that step may add dummy <image> tokens that
        # aren't part of the tokenizer's vocabulary.
        if not is_token_prompt(prompt):
            return

        token_ids = prompt["prompt_token_ids"]
        if len(token_ids) == 0:
            # Empty prompt check is handled later
            return

        largest_id = max(token_ids)
        if largest_id > tokenizer.max_token_id:
            raise ValueError(
                "Token id {} is out of vocabulary".format(largest_id))

789
790
791
792
793
    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup driven by SamplingParams.

        Args:
            request_id: ID of the request the group belongs to.
            seq: The sequence placed in the group.
            sampling_params: Sampling parameters of the request; a clone is
                stored in the group, the argument itself is not updated here.
            arrival_time: Arrival time recorded on the group.
            lora_request: Optional LoRA request.
            trace_headers: Optional trace headers.
            prompt_adapter_request: Optional prompt adapter request.
            encoder_seq: Optional encoder sequence.
            priority: Priority recorded on the group.

        Returns:
            The newly created :class:`SequenceGroup`.

        Raises:
            ValueError: If ``logprobs`` or ``prompt_logprobs`` asks for more
                than the model config's ``max_logprobs``.
        """
        logprobs_limit = self.get_model_config().max_logprobs
        requested_counts = (sampling_params.logprobs,
                            sampling_params.prompt_logprobs)
        # A falsy count (None or 0) means "not requested" and is never an
        # error.
        if any(count and count > logprobs_limit
               for count in requested_counts):
            raise ValueError(f"Cannot request more than "
                             f"{logprobs_limit} logprobs.")

        params_with_processors = self._build_logits_processors(
            sampling_params, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        group_params = params_with_processors.clone()
        group_params.update_from_generation_config(
            self.generation_config_fields, seq.eos_token_id)

        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             sampling_params=group_params,
                             lora_request=lora_request,
                             trace_headers=trace_headers,
                             prompt_adapter_request=prompt_adapter_request,
                             encoder_seq=encoder_seq,
                             priority=priority)

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup driven by PoolingParams.

        Args:
            request_id: ID of the request the group belongs to.
            seq: The sequence placed in the group.
            pooling_params: Pooling parameters of the request; a clone is
                stored in the group.
            arrival_time: Arrival time recorded on the group.
            lora_request: Optional LoRA request.
            prompt_adapter_request: Optional prompt adapter request.
            encoder_seq: Optional encoder sequence.
            priority: Priority recorded on the group.

        Returns:
            The newly created :class:`SequenceGroup`.
        """
        # Defensive copy of PoolingParams, which are used by the pooler
        group_params = pooling_params.clone()

        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             lora_request=lora_request,
                             pooling_params=group_params,
                             prompt_adapter_request=prompt_adapter_request,
                             encoder_seq=encoder_seq,
                             priority=priority)
859

Antoni Baum's avatar
Antoni Baum committed
860
861
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        Args:
            request_id: The ID(s) of the request to abort. Either a single ID
                string or any iterable of ID strings (one-shot iterators such
                as generators are supported).

        Details:
            - Calls ``abort_seq_group`` on every scheduler in
              ``self.scheduler``. Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        # The same argument is handed to every scheduler. A one-shot iterable
        # (e.g. a generator) would be exhausted by the first scheduler and
        # look empty to the rest, so materialize it once up front. A plain
        # string is a single ID and is passed through untouched.
        if not isinstance(request_id, str):
            request_id = tuple(request_id)

        for scheduler in self.scheduler:
            scheduler.abort_seq_group(request_id)
879

880
881
882
883
    def get_model_config(self) -> ModelConfig:
        """Return the engine's model configuration.

        Returns:
            The object stored in ``self.model_config``.
        """
        return self.model_config

884
885
886
887
    def get_parallel_config(self) -> ParallelConfig:
        """Return the engine's parallel configuration.

        Returns:
            The object stored in ``self.parallel_config``.
        """
        return self.parallel_config

888
889
890
891
    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's decoding configuration.

        Returns:
            The object stored in ``self.decoding_config``.
        """
        return self.decoding_config

892
893
894
895
896
897
898
899
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the engine's scheduler configuration.

        Returns:
            The object stored in ``self.scheduler_config``.
        """
        return self.scheduler_config

    def get_lora_config(self) -> LoRAConfig:
        """Return the engine's LoRA configuration.

        Returns:
            The object stored in ``self.lora_config``.

        NOTE(review): ``add_request`` treats a falsy ``self.lora_config`` as
        "LoRA is not enabled", so this presumably can return a falsy value
        (e.g. None) despite the annotation -- confirm before relying on it.
        """
        return self.lora_config

900
    def get_num_unfinished_requests(self) -> int:
        """Count the unfinished requests across all schedulers.

        Returns:
            The sum of ``get_num_unfinished_seq_groups()`` over every
            scheduler in ``self.scheduler`` (0 when there are none).
        """
        unfinished_total = 0
        for engine_scheduler in self.scheduler:
            unfinished_total += engine_scheduler.get_num_unfinished_seq_groups()
        return unfinished_total
904

905
    def has_unfinished_requests(self) -> bool:
        """Tell whether any scheduler still has unfinished sequences.

        Returns:
            True as soon as one scheduler in ``self.scheduler`` reports
            ``has_unfinished_seqs()``; False otherwise, including when there
            are no schedulers.
        """
        for engine_scheduler in self.scheduler:
            if engine_scheduler.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Tell whether one virtual engine still has unfinished requests.

        Args:
            virtual_engine: Index into ``self.scheduler`` selecting the
                scheduler to query.

        Returns:
            The result of that scheduler's ``has_unfinished_seqs()``.
        """
        engine_scheduler = self.scheduler[virtual_engine]
        return engine_scheduler.has_unfinished_seqs()
916

917
    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Store the pooled result on a sequence group and finish its seqs.

        Args:
            seq_group: Group that receives the pooled data; modified in place.
            outputs: Outputs for the group. Only the first element is read,
                so the list must not be empty.
        """
        first_output = outputs[0]
        seq_group.pooled_data = first_output.data

        # Every sequence of the group is marked as stopped.
        finished_status = SequenceStatus.FINISHED_STOPPED
        for sequence in seq_group.get_seqs():
            sequence.status = finished_status

929
930
931
932
933
934
935
936
    def _update_num_computed_tokens_for_multi_step_prefill(
            self, seq_group: SequenceGroup,
            seq_group_meta: SequenceGroupMetadata,
            is_first_step_output: Optional[bool]):
        """
        Advance num_computed_tokens for a prompt sequence group when
        Multi-Step is enabled. Must only be called with multi-step enabled.

        seq_group: SequenceGroup to update the num_computed_tokens for.
        seq_group_meta: Metadata of the given SequenceGroup.
        is_first_step_output: Optional[bool] -
            When available, is_first_step_output indicates if the appended
            output token is the output of the first-step in multi-step.
            A value of None indicates that outputs from all steps in
            multi-step are submitted in a single burst.
        """
        assert self.scheduler_config.is_multi_step

        if not seq_group_meta.is_prompt:
            # num_computed_token updates for multi-step decodes happen after
            # the tokens are appended to the sequence.
            return

        if self.scheduler_config.chunked_prefill_enabled:
            # In multi-step + chunked-prefill case, the prompt sequences
            # that are scheduled are fully processed in the first step, so
            # an output known to come from a later step must not update
            # the count again.
            if is_first_step_output is not None and not is_first_step_output:
                return
        else:
            # Normal multi-step decoding case. In this case prompt-sequences
            # are actually single-stepped. Always update in this case.
            assert seq_group.state.num_steps == 1

        seq_group.update_num_computed_tokens(seq_group_meta.token_chunk_size)

968
969
970
971
972
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups.

        One entry of ``ctx.output_queue`` is handled per call: it is only
        peeked when ``request_id`` is given and popped otherwise. The affected
        sequence groups are updated, the resulting request outputs are
        appended to ``ctx.request_outputs`` and, when
        ``self.process_request_outputs_callback`` is set, handed to that
        callback. Nothing is returned.

        ctx: The virtual engine context to work on
        request_id: If provided, then only this request is going to be processed
        """

        now = time.time()

        if len(ctx.output_queue) == 0:
            return

        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        if has_multiple_outputs:
            assert self.scheduler_config.is_multi_step or \
                     self.speculative_config
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
            outputs_by_sequence_group = create_output_by_sequence_group(
                outputs, num_seq_groups=len(seq_group_metadata_list))
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
        else:
            outputs_by_sequence_group = outputs

        # Snapshot of the already-processed indices for O(1) membership
        # tests. ``skip`` itself is only appended to in the single-request
        # branch below, which returns right afterwards, so the snapshot
        # stays accurate for every lookup in this call.
        skipped = set(skip)

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skipped  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            if i in skipped:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            if not is_async:
                if self.scheduler_config.is_multi_step:
                    # Updates happen only if the sequence is prefill
                    self._update_num_computed_tokens_for_multi_step_prefill(
                        seq_group, seq_group_meta, is_first_step_output)
                else:
                    seq_group.update_num_computed_tokens(
                        seq_group_meta.token_chunk_size or 0)

            # Accumulate the model timing of every sampler output into the
            # group's metrics (a missing timing counts as 0 once a total
            # already exists).
            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            if self.model_config.runner_type == "pooling":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Create the outputs
        # Indices that were skipped, had finished earlier, or already got
        # their final output above must not be processed again.
        already_handled = skipped.union(finished_before, finished_now)
        for i in indices:
            if i in already_handled:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        for seq_group in scheduler_outputs.ignored_seq_groups:
            params = seq_group.sampling_params
            # In DELTA output mode an unfinished ignored group produces no
            # output.
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Append the sampled tokens of a single model run to their sequences.

        This is normally done inside the output processor, but it is required
        here if the worker is to perform an async forward pass to the next
        step.

        Args:
            output: Model output of the run. It is iterated in lockstep with
                the two lists below; each element must expose ``samples``.
            seq_group_metadata_list: Metadata of the scheduled groups.
            scheduled_seq_groups: The scheduled groups themselves.
        """
        multi_step = self.scheduler_config.is_multi_step

        lockstep = zip(seq_group_metadata_list, output, scheduled_seq_groups)
        for metadata, group_output, scheduled in lockstep:
            seq_group = scheduled.seq_group

            if seq_group.is_finished():
                continue

            if multi_step:
                # Updates happen only if the sequence is prefill
                self._update_num_computed_tokens_for_multi_step_prefill(
                    seq_group, metadata, seq_group.state.num_steps == 1)
            else:
                chunk_size = metadata.token_chunk_size
                seq_group.update_num_computed_tokens(
                    0 if chunk_size is None else chunk_size)

            if not metadata.do_sample:
                continue

            assert len(group_output.samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1)")
            sample = group_output.samples[0]

            assert len(seq_group.seqs) == 1
            seq = seq_group.seqs[0]

            # Must be decided before the token is appended: in multi-step
            # mode the extra computed token is only counted when the
            # sequence still had uncomputed tokens at this point.
            count_appended_token = (
                multi_step and seq.data.get_num_uncomputed_tokens() != 0)
            seq.append_token_id(sample.output_token, sample.logprobs)
            if count_appended_token:
                seq_group.update_num_computed_tokens(1)
1231

1232
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
Antoni Baum's avatar
Antoni Baum committed
1233
1234
        """Performs one decoding iteration and returns newly generated results.

1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refer to a group of sequences
                  that are generated from the same prompt.

1250
            - Step 2: Calls the distributed executor to execute the model.
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
1272
            >>>         engine.add_request(str(req_id),prompt,sampling_params)
1273
1274
1275
1276
1277
1278
1279
1280
1281
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
Antoni Baum's avatar
Antoni Baum committed
1282
        """
1283
1284
1285
1286
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")
1287

1288
        # For llm_engine, there is no pipeline parallel support, so the engine
1289
        # used is always 0.
1290
1291
        virtual_engine = 0

1292
1293
        # These are cached outputs from previous iterations. None if on first
        # iteration
1294
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
1295
1296
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
1297
        allow_async_output_proc = cached_outputs.allow_async_output_proc
1298

1299
1300
        ctx = self.scheduler_contexts[virtual_engine]

1301
1302
1303
        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

1304
1305
1306
1307
        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
1308
            # Schedule iteration
1309
            (seq_group_metadata_list, scheduler_outputs,
1310
1311
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
1312

1313
1314
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
1315

1316
1317
1318
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

1319
1320
            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
1321
                self._process_model_outputs(ctx=ctx)
1322

1323
1324
1325
1326
1327
            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
1328
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
1329
                    allow_async_output_proc)
1330
1331
        else:
            finished_requests_ids = list()
1332
1333
1334

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
Antoni Baum's avatar
Antoni Baum committed
1335

1336
        if not scheduler_outputs.is_empty():
1337
1338
1339
1340
1341
1342

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
1343
                self._get_last_sampled_token_ids(virtual_engine)
1344

1345
            execute_model_req = ExecuteModelRequest(
1346
1347
1348
1349
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
1350
1351
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
1352
1353
1354
1355
1356
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

1357
            if allow_async_output_proc:
1358
1359
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]
1360

1361
            outputs = self.model_executor.execute_model(
1362
                execute_model_req=execute_model_req)
1363

1364
            # We need to do this here so that last step's sampled_token_ids can
1365
1366
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
1367
                self._update_cached_scheduler_output(virtual_engine, outputs)
1368
        else:
1369
1370
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
1371
1372
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
1373
            # No outputs in this case
1374
            outputs = []
Antoni Baum's avatar
Antoni Baum committed
1375

1376
1377
1378
1379
1380
1381
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
1382
            # clear the cache if we have finished all the steps.
1383
1384
1385
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[0] = SchedulerOutputState()

1386
1387
1388
1389
1390
1391
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

1392
            # Add results to the output_queue
1393
1394
1395
1396
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
1397
1398
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)
1399
1400
1401

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
1402
                    "Async postprocessor expects only a single output set")
1403

1404
                self._advance_to_next_step(
1405
                    outputs[0], seq_group_metadata_list,
1406
                    scheduler_outputs.scheduled_seq_groups)
1407

1408
            # Check if need to run the usual non-async path
1409
            if not allow_async_output_proc:
1410
                self._process_model_outputs(ctx=ctx)
1411

1412
                # Log stats.
1413
                self.do_log_stats(scheduler_outputs, outputs)
1414

1415
1416
1417
                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
1418
            # Multi-step case
1419
            return ctx.request_outputs
1420

1421
        if not self.has_unfinished_requests():
1422
1423
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
1424
                self._process_model_outputs(ctx=ctx)
1425
            assert len(ctx.output_queue) == 0
1426

1427
1428
1429
1430
1431
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
1432
            logger.debug("Stopping remote worker execution loop.")
1433
1434
            self.model_executor.stop_remote_worker_execution_loop()

1435
        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1436

1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for nowto make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        if any([
                seq_group.state.remaining_steps != ref_remaining_steps
                for seq_group in seq_group_metadata_list[1:]
        ]):
1452
1453
            raise AssertionError("All running sequence groups should "
                                 "have the same remaining steps.")
1454
1455
1456
1457
1458
1459

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
1460
1461
1462
1463
1464
1465
1466
1467
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                and output[0] is not None):
            last_output = output[-1]
            assert last_output is not None
            assert last_output.sampled_token_ids_cpu is not None
            assert last_output.sampled_token_ids is None
            assert last_output.sampled_token_probs is None
            self.cached_scheduler_outputs[
                virtual_engine].last_output = last_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

1493
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` under ``logger_name`` in ``self.stat_loggers``.

        Raises:
            RuntimeError: if stat logging is disabled on this engine.
            KeyError: if a logger with that name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")

        registry = self.stat_loggers
        if logger_name in registry:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        registry[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under ``logger_name``.

        Raises:
            RuntimeError: if stat logging is disabled on this engine.
            KeyError: if no logger with that name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")

        registry = self.stat_loggers
        if logger_name not in registry:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del registry[logger_name]

1511
1512
1513
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Collect stats once and hand them to every registered stat logger.

        Does nothing when ``self.log_stats`` is falsy. All arguments are
        forwarded unchanged to ``_get_stats``; with the defaults this acts
        as a forced log when no requests are active.
        """
        if not self.log_stats:
            return

        snapshot = self._get_stats(scheduler_outputs, model_output,
                                   finished_before, skip)
        for stat_logger in self.stat_loggers.values():
            stat_logger.log(snapshot)
1522

1523
1524
1525
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Build the ``Stats`` snapshot handed to the stat loggers.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
            finished_before: Optional, indices of sequences that were finished
                before. These sequences will be ignored.
            skip: Optional, indices of sequences that were preempted. These
                sequences will be ignored.

        Returns:
            A ``Stats`` object with system, iteration and request metrics.
        """
        now = time.time()
        schedulers = self.scheduler

        # ---- System state ----
        #   Scheduler queues, summed over every scheduler.
        num_running_sys = sum(len(sched.running) for sched in schedulers)
        num_swapped_sys = sum(len(sched.swapped) for sched in schedulers)
        num_waiting_sys = sum(len(sched.waiting) for sched in schedulers)

        #   KV cache usage as a fraction of total blocks. The truthiness
        #   test guards against both None and 0 totals.
        gpu_cache_usage_sys = 0.
        gpu_blocks_total = self.cache_config.num_gpu_blocks
        if gpu_blocks_total:
            gpu_blocks_free = sum(
                sched.block_manager.get_num_free_gpu_blocks()
                for sched in schedulers)
            gpu_cache_usage_sys = 1.0 - (gpu_blocks_free / gpu_blocks_total)

        cpu_cache_usage_sys = 0.
        cpu_blocks_total = self.cache_config.num_cpu_blocks
        if cpu_blocks_total:
            cpu_blocks_free = sum(
                sched.block_manager.get_num_free_cpu_blocks()
                for sched in schedulers)
            cpu_cache_usage_sys = 1.0 - (cpu_blocks_free / cpu_blocks_total)

        #   Prefix cache hit rate. Note that we always use the cache hit
        #   rate of the first virtual engine.
        first_scheduler = schedulers[0]
        cpu_prefix_cache_hit_rate = first_scheduler.get_prefix_cache_hit_rate(
            Device.CPU)
        gpu_prefix_cache_hit_rate = first_scheduler.get_prefix_cache_hit_rate(
            Device.GPU)

        # ---- Iteration stats ----
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        if scheduler_outputs is None:
            num_preemption_iter = 0
        else:
            num_preemption_iter = scheduler_outputs.preempted

        # ---- Request stats ----
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        time_in_queue_requests: List[float] = []
        model_forward_time_requests: List[float] = []
        model_execute_time_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # ---- LoRA requests ----
        # Unique adapter names in first-seen order (dicts keep insertion
        # order), for running and waiting requests respectively.
        running_lora_adapters = list(
            dict.fromkeys([
                request.lora_request.lora_name for sched in schedulers
                for request in sched.running if request.lora_request
            ]))
        waiting_lora_adapters = list(
            dict.fromkeys([
                request.lora_request.lora_name for sched in schedulers
                for request in sched.waiting if request.lora_request
            ]))
        max_lora_stat = (str(self.lora_config.max_loras)
                         if self.lora_config else "0")

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.
            num_prefill_groups = scheduler_outputs.num_prefill_groups

            for idx, scheduled in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if finished_before and idx in finished_before:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if skip and idx in skip:
                    continue

                seq_group = scheduled.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # while its index still falls in the prefill range.
                if idx < num_prefill_groups:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += scheduled.token_chunk_size

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        time_to_first_tokens_iter.append(
                            seq_group.get_last_token_latency())

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    time_per_output_tokens_iter.append(
                        seq_group.get_last_token_latency())

                    step_state = seq_group.state
                    # For async_output_proc, the do_log_stats()
                    # is called following init_multi_step(), which
                    # sets the current_step to zero; in that case the
                    # full num_steps is used instead.
                    if step_state.current_step == 0:
                        steps_taken = step_state.num_steps
                    else:
                        steps_taken = step_state.current_step
                    actual_num_batched_tokens += steps_taken - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if not seq_group.is_finished():
                    continue

                # Latency timings
                req_metrics = seq_group.metrics
                arrival = req_metrics.arrival_time
                time_e2e_requests.append(now - arrival)

                first_scheduled = req_metrics.first_scheduled_time
                first_token = req_metrics.first_token_time
                if first_scheduled is not None and first_token is not None:
                    time_queue_requests.append(first_scheduled - arrival)
                    time_prefill_requests.append(first_token - first_scheduled)
                    time_decode_requests.append(now - first_token)
                    time_inference_requests.append(now - first_scheduled)

                if req_metrics.time_in_queue is not None:
                    time_in_queue_requests.append(req_metrics.time_in_queue)
                if req_metrics.model_forward_time is not None:
                    model_forward_time_requests.append(
                        req_metrics.model_forward_time)
                if req_metrics.model_execute_time is not None:
                    model_execute_time_requests.append(
                        req_metrics.model_execute_time * 1000)

                # Metadata
                num_prompt_tokens_requests.append(
                    len(seq_group.prompt_token_ids))
                num_generation_tokens_requests.extend(
                    seq.get_output_len()
                    for seq in seq_group.get_finished_seqs())
                max_num_generation_tokens_requests.append(
                    max(seq.get_output_len() for seq in seq_group.get_seqs()))

                sampling_params = seq_group.sampling_params
                if sampling_params is not None:
                    n_requests.append(sampling_params.n)
                    max_tokens_requests.append(sampling_params.max_tokens)

                finished_reason_requests.extend(
                    SequenceStatus.get_finished_reason(seq.status)
                    for seq in seq_group.get_finished_seqs())

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output; the field is simply None when there are none.
        spec_decode_metrics = None
        if model_output and isinstance(model_output[0], SamplerOutput):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            time_in_queue_requests=time_in_queue_requests,
            model_forward_time_requests=model_forward_time_requests,
            model_execute_time_requests=model_execute_time_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            max_lora=max_lora_stat,
            waiting_lora_adapters=waiting_lora_adapters,
            running_lora_adapters=running_lora_adapters)
1792

1793
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model executor's ``add_lora``.

        Returns whatever the executor returns.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
1795
1796

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``remove_lora``.

        Returns whatever the executor returns.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
1798

1799
    def list_loras(self) -> Set[int]:
        """Return the result of the model executor's ``list_loras``."""
        executor = self.model_executor
        return executor.list_loras()
1801

1802
1803
1804
    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``pin_lora``.

        Returns whatever the executor returns.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward the request to the executor's ``add_prompt_adapter``.

        Returns whatever the executor returns.
        """
        executor = self.model_executor
        return executor.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward the id to the executor's ``remove_prompt_adapter``.

        Returns whatever the executor returns.
        """
        executor = self.model_executor
        return executor.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> List[int]:
        """Return the result of the executor's ``list_prompt_adapters``."""
        executor = self.model_executor
        return executor.list_prompt_adapters()

1815
1816
1817
1818
1819
1820
    def start_profile(self) -> None:
        """Ask the model executor to start profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Ask the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

1821
    def check_health(self) -> None:
        """Run the health checks of the tokenizer (if any) and the executor.

        The tokenizer check is skipped when ``self.tokenizer`` is falsy; the
        model executor check always runs, after the tokenizer's.
        """
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()

        executor = self.model_executor
        executor.check_health()
1825
1826
1827
1828

    def is_tracing_enabled(self) -> bool:
        """Tell whether a tracer is configured (``self.tracer`` is set)."""
        tracer = self.tracer
        return not (tracer is None)

1829
1830
1831
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for each finished scheduled sequence group.

        Does nothing when no tracer is configured.

        Args:
            scheduler_outputs: supplies ``scheduled_seq_groups`` to inspect.
            finished_before: Optional, indices into ``scheduled_seq_groups``
                that must be skipped to avoid double tracing when using
                async output processing.
        """
        if self.tracer is None:
            return

        for position, scheduled in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            already_traced = (bool(finished_before)
                              and position in finished_before)
            if already_traced:
                continue

            group = scheduled.seq_group
            if group.is_finished():
                self.create_trace_span(group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit one "llm_request" server span describing ``seq_group``.

        Does nothing when no tracer is configured or the group has no
        sampling params. The span starts at the request's arrival time and
        carries request parameters, token usage and latency attributes.

        Latency attributes whose source timestamps may be missing
        (``time_in_queue``, ``first_token_time``, and the optional
        scheduler/model timings) are only set when the value is not None,
        so a request without them no longer raises ``TypeError`` or records
        a None attribute.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics
            sampling_params = seq_group.sampling_params

            # Request identity and sampling parameters.
            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   sampling_params.n)

            # Token usage.
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Latencies. The stats path in this class treats time_in_queue
            # and first_token_time as possibly None, so guard them here too.
            if metrics.time_in_queue is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                    metrics.time_in_queue)
            if metrics.first_token_time is not None:
                ttft = metrics.first_token_time - metrics.arrival_time
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            e2e_time = metrics.finished_time - metrics.arrival_time
            seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)

            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # NOTE(review): the division suggests model_forward_time is
                # kept in milliseconds and reported in seconds; confirm.
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)
1899

1900
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Reject empty prompts and, for multimodal models, oversized ones.

        Raises:
            ValueError: if the selected prompt has no token ids, or if the
                model is multimodal and the prompt is longer than
                ``model_config.max_model_len``.
        """
        model_config = self.model_config

        if is_encoder_decoder_inputs(inputs):
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            if model_config.is_multimodal_model:
                prompt_inputs = inputs["decoder"]
            else:
                prompt_inputs = inputs["encoder"]
        else:
            prompt_inputs = inputs

        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids

        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        # The length limit below is only enforced for multimodal models.
        if not model_config.is_multimodal_model:
            return

        max_prompt_len = model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            raise ValueError(
                f"The prompt (total length {len(prompt_ids)}) is too long "
                f"to fit into the model (context length {max_prompt_len}). "
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens plus multimodal tokens. For image "
                "inputs, the number of image tokens depends on the number "
                "of images, and possibly their aspect ratios as well.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the guided_decoding,
        logits_bias, and allowed_token_ids fields in sampling_params. Deletes
        those fields and adds the constructed logits processors to the
        logits_processors field. Returns the modified sampling params."""

        logits_processors = []
1940

1941
1942
1943
1944
1945
        if sampling_params.guided_decoding is not None:
            # Defensively copy sampling params since guided decoding logits
            # processors can have different state for each request
            sampling_params = copy.copy(sampling_params)
            guided_decoding = sampling_params.guided_decoding
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            processor = get_local_guided_decoding_logits_processor(
1956
1957
1958
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config)
1959
1960
1961
1962
1963
1964
1965
1966
1967
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
            tokenizer = self.get_tokenizer(lora_request=lora_request)

1968
            processors = get_openai_logits_processors(
1969
1970
1971
1972
1973
1974
1975
1976
1977
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

1978
1979
1980
1981
1982
1983
        if len(sampling_params.bad_words) > 0:
            tokenizer = self.get_tokenizer(lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

1984
1985
1986
1987
1988
1989
1990
        if logits_processors:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = logits_processors
            else:
                sampling_params.logits_processors.extend(logits_processors)

        return sampling_params