llm_engine.py 81.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
import os
6
import copy
7

Antoni Baum's avatar
Antoni Baum committed
8
import time
9
from collections import Counter as collectionsCounter
10
from collections import deque
11
from contextlib import contextmanager
12
from dataclasses import dataclass
13
from functools import partial
14
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
15
                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
16
from typing import Sequence as GenericSequence
17
from typing import Set, Type, Union, cast
18

19
import torch
20
from typing_extensions import TypeVar
21

22
import vllm.envs as envs
23
24
25
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         VllmConfig)
26
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.engine.arg_utils import EngineArgs
28
from vllm.engine.metrics_types import StatLoggerBase, Stats
29
30
31
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
32
33
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
34
from vllm.executor.executor_base import ExecutorBase
35
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
36
from vllm.inputs.parse import split_enc_dec_inputs
37
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.logger import init_logger
39
from vllm.logits_process import get_bad_words_logits_processors
40
from vllm.lora.request import LoRARequest
41
from vllm.model_executor.layers.sampler import SamplerOutput
42
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
43
from vllm.multimodal.processing import EncDecMultiModalProcessor
44
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
45
46
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
47
from vllm.sampling_params import RequestOutputKind, SamplingParams
48
49
50
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                           SequenceGroupBase, SequenceGroupMetadata,
zhuwenwen's avatar
zhuwenwen committed
51
                           SequenceGroupOutput, SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID)
52
53
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
54
from vllm.transformers_utils.detokenizer import Detokenizer
55
from vllm.transformers_utils.tokenizer import AnyTokenizer
56
57
# DEBUG add cpm tokenizer
from vllm.transformers_utils.tokenizers import CPM9GTokenizer
58
from vllm.transformers_utils.tokenizer_group import (
59
    TokenizerGroup, init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
60
61
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
62
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
63
from vllm.version import __version__ as VLLM_VERSION
64
from vllm.worker.model_runner_base import InputProcessingError
65
from vllm.profiler.prof import profile
66
67

logger = init_logger(__name__)
68
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
69

70
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
71
_R = TypeVar("_R", default=Any)
72
73


74
75
76
77
78
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    # Cached sequence-group metadata; None until something is cached.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Cached scheduler outputs; None until something is cached.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # NOTE(review): presumably records whether async output processing was
    # allowed for the cached step -- confirm against the code that sets it.
    allow_async_output_proc: bool = False
    # Most recent sampler output, if any has been stored.
    last_output: Optional[SamplerOutput] = None
81
82


83
84
85
86
87
88
class OutputData(NamedTuple):
    """One queued batch of model outputs plus the scheduling state it
    was produced from."""
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # Indicates if this output is from the first step of the
    # multi-step. When multi-step is disabled, this is always
    # set to True.
    # is_first_step_output is invalid when `outputs` has
    # outputs from multiple steps.
    is_first_step_output: Optional[bool]
    # NOTE(review): presumably indices of entries to skip when the outputs
    # are processed -- confirm against the consumer of this tuple.
    skip: List[int]


98
class SchedulerContext:
    """Holds queued model outputs and produced request outputs."""

    def __init__(self) -> None:
        # FIFO of model outputs appended via `append_output`.
        self.output_queue: Deque[OutputData] = deque()
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Wrap the arguments in an `OutputData` and enqueue it.

        The `skip` list of the new entry always starts out empty.
        """
        self.output_queue.append(
            OutputData(outputs=outputs,
                       seq_group_metadata_list=seq_group_metadata_list,
                       scheduler_outputs=scheduler_outputs,
                       is_async=is_async,
                       is_last_step=is_last_step,
                       is_first_step_output=is_first_step_output,
                       skip=[]))
121
122


123
class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The [`LLM`][vllm.LLM] class wraps this class for offline batched inference
    and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine]
    class wraps this class for online serving.

    The config arguments are derived from [`EngineArgs`][vllm.EngineArgs].

    Args:
        vllm_config: The configuration for initializing and running vLLM.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
    """
146

147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return `output` typed as `output_type`, optionally checking it.

        Args:
            output: The object to validate.
            output_type: The expected output class.

        Returns:
            `output` itself, cast to `output_type` for the type checker.

        Raises:
            TypeError: If validation is enabled (`DO_VALIDATE_OUTPUT`) and
                `output` is not an instance of `output_type`.
        """
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

        return cast(_O, output)
173
174
175

    @classmethod
    def validate_outputs(
zhuwenwen's avatar
zhuwenwen committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        do_validate = cls.DO_VALIDATE_OUTPUT

        outputs_: List[_O]
        if TYPE_CHECKING or do_validate:
            outputs_ = []
            for output in outputs:
                if not isinstance(output, output_type):
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(output)}")

                outputs_.append(output)
        else:
            outputs_ = outputs

194
195
        return outputs_

196
    tokenizer: Optional[TokenizerGroup]
197

198
199
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the V0 engine: tokenizer, executor, schedulers and loggers.

        Args:
            vllm_config: The configuration for initializing and running vLLM.
            executor_class: The model executor class; it is instantiated
                with `vllm_config`.
            log_stats: Whether to set up stat loggers.
            usage_context: Entry point reported with the usage statistics.
            stat_loggers: Stat loggers to use instead of the default
                logging/Prometheus pair. Only consulted when `log_stats`
                is true.
            mm_registry: Multi-modal registry handed to the input
                preprocessor.
            use_cached_outputs: Stored on the engine and included in the
                start-up log line.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is set.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None
        elif self.model_config.tokenizer_mode == "cpm":
            # CPM mode: use the CPM9G tokenizer directly instead of a
            # tokenizer group built from the configs.
            self.tokenizer = CPM9GTokenizer(self.model_config.model,
                                            trust_remote_code=True)
            self.detokenizer = Detokenizer(self.tokenizer,
                                           self.model_config.tokenizer_mode)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.model_executor = executor_class(vllm_config=vllm_config)

        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        # One cached state and one context per pipeline-parallel stage.
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(self.scheduler_config.max_model_len,
                                         get_tokenizer_for_seq),
            ))

        # True only when the VLLM_TREE_DECODING environment variable is
        # exactly "1".
        self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False
        profile.StartTracer()

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

408
409
410
411
412
413
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache. If `cache_config.num_gpu_blocks_override` is
        set, it replaces the profiled GPU block count. The final counts are
        written back to `cache_config` and passed to the executor.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments during initialization.
        start = time.perf_counter()
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
433

434
    @classmethod
435
    def _get_executor_cls(cls,
436
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
437
        # distributed_executor_backend must be set in VllmConfig.__post_init__
438
439
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
440
        # Initialize the cluster and specify the executor class.
441
442
443
444
445
446
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
447
448
449
450
451
452
453
454
455
456
457
458
459
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingDistributedExecutor
        elif distributed_executor_backend == "uni":
            # JAX-style, single-process, multi-device executor.
460
461
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
462
463
464
465
466
467
468
469
        elif distributed_executor_backend == "external_launcher":
            # executor with external launcher
            from vllm.executor.uniproc_executor import (  # noqa
                ExecutorWithExternalLauncher)
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("unrecognized distributed_executor_backend: "
                             f"{distributed_executor_backend}")
470
471
        return executor_class

472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built `VllmConfig`.

        The executor class is resolved from `vllm_config`, and stat
        logging is enabled unless `disable_log_stats` is true.
        """
        executor_cls = cls._get_executor_cls(vllm_config)
        should_log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=should_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

488
489
490
491
492
493
494
495
496
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        When `envs.VLLM_USE_V1` is set, construction is delegated to the
        V1 `LLMEngine` class instead of this one.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)

        engine_cls = cls
        if envs.VLLM_USE_V1:
            # Imported lazily so the V1 engine is only loaded when selected.
            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
            engine_cls = V1LLMEngine

        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
        )
510

511
512
513
514
515
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

516
517
518
519
520
521
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

522
523
    def get_tokenizer_group(self) -> TokenizerGroup:
        if self.tokenizer is None:
524
525
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
526

527
        return self.tokenizer
528

529
    def get_tokenizer(
530
531
532
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
533
534
535
536
        if self.model_config.tokenizer_mode == "cpm":
            return self.tokenizer
        else:
            return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
537

538
    def _init_tokenizer(self) -> TokenizerGroup:
        """Build the tokenizer group from the model, scheduler and LoRA
        configs held by the engine."""
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            lora_config=self.lora_config)
543

544
545
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
546
        self.cache_config.verify_with_parallel_config(self.parallel_config)
547
548
549
550
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
551

552
553
554
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Args:
            request_id: The ID of the request.
            processed_inputs: The already-preprocessed inputs.
            params: Sampling or pooling parameters for the request.
            arrival_time: The arrival time of the request.
            lora_request: The LoRA request, if any.
            trace_headers: Trace headers forwarded to the sequence group.
            priority: The priority of the request.

        Returns:
            The created sequence group, or None when `params` is a
            `SamplingParams` with `n > 1` (the request is then handed to
            `ParallelSampleSequenceGroup.add_request`).

        Raises:
            ValueError: If `params` is neither `SamplingParams` nor
                `PoolingParams`.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        # DEBUG/TODO (original author's note): "change tokenizer false".
        # In "cpm" mode the EOS id is read straight from the tokenizer.
        if self.model_config.tokenizer_mode == "cpm":
            eos_token_id = self.tokenizer.eos_id
        else:
            eos_token_id = self.input_preprocessor.get_eos_token_id(
                lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request)

        encoder_seq = (None if encoder_inputs is None else Sequence(
            seq_id, encoder_inputs, block_size, eos_token_id, lora_request))

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        # min() returns the first minimal element, so ties still go to the
        # lowest-indexed scheduler.
        min_cost_scheduler = min(
            self.scheduler,
            key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group

630
631
    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()
632

633
634
635
    def add_request(
        self,
        request_id: str,
636
        prompt: PromptType,
637
        params: Union[SamplingParams, PoolingParams],
638
        arrival_time: Optional[float] = None,
639
        lora_request: Optional[LoRARequest] = None,
640
        tokenization_kwargs: Optional[dict[str, Any]] = None,
641
        trace_headers: Optional[Mapping[str, str]] = None,
642
        priority: int = 0,
643
    ) -> None:
Zhuohan Li's avatar
Zhuohan Li committed
644
        """Add a request to the engine's request pool.
645
646

        The request is added to the request pool and will be processed by the
Zhuohan Li's avatar
Zhuohan Li committed
647
        scheduler as `engine.step()` is called. The exact scheduling policy is
648
649
650
651
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
652
653
            prompt: The prompt to the LLM. See
                [PromptType][vllm.inputs.PromptType]
654
655
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
656
657
                [SamplingParams][vllm.SamplingParams] for text generation.
                [PoolingParams][vllm.PoolingParams] for pooling.
658
            arrival_time: The arrival time of the request. If None, we use
659
                the current monotonic time.
660
            lora_request: The LoRA request to add.
661
            trace_headers: OpenTelemetry trace headers.
662
663
            priority: The priority of the request.
                Only applicable with priority scheduling.
664
665
666
667

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
668
669
670
671
672
            - Create `n` number of [Sequence][vllm.Sequence] objects.
            - Create a [SequenceGroup][vllm.SequenceGroup] object
              from the list of [Sequence][vllm.Sequence].
            - Add the [SequenceGroup][vllm.SequenceGroup] object to the
              scheduler.
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
689
        """
690
691
692
693
        if not isinstance(request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request_id)}")

694
695
696
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
697

698
        if priority != 0 and not self.scheduler_config.policy == "priority":
699
700
701
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

702
        if isinstance(params, SamplingParams) \
703
            and params.logits_processors:
704
            raise ValueError(
705
                "Logits processors are not supported in multi-step decoding")
706

707
        if arrival_time is None:
708
            arrival_time = time.time()
709

710
711
712
713
714
        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt["prompt_token_ids"] = [0] * seq_len
715

716
717
718
719
        #DEBUG anrongqiao
        if self.model_config.tokenizer_mode == "cpm":
            lora_request = None
            
720
        processed_inputs = self.input_preprocessor.preprocess(
721
            prompt,
722
            tokenization_kwargs=tokenization_kwargs,
723
            lora_request=lora_request,
724
        )
725

726
727
728
729
730
731
        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
732
            trace_headers=trace_headers,
733
            priority=priority,
734
        )
735
736
737
738
739
740

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a sampling request.

        Args:
            request_id: ID stored on the resulting group.
            seq: The one sequence the group will contain.
            sampling_params: Caller-provided sampling parameters; a clone
                (not the caller's object) is stored on the group.
            arrival_time: Arrival timestamp forwarded to the group.
            lora_request: Optional LoRA request forwarded to the group and to
                the logits-processor builder.
            trace_headers: Optional tracing headers forwarded to the group.
            encoder_seq: Optional encoder sequence forwarded to the group.
            priority: Request priority forwarded to the group.

        Returns:
            The newly constructed SequenceGroup.

        Raises:
            ValueError: If `logprobs` or `prompt_logprobs` is set and exceeds
                the model config's `max_logprobs`.
        """
        max_logprobs = self.get_model_config().max_logprobs
        # Both limits share the same cap; an unset (falsy) value is allowed.
        for requested in (sampling_params.logprobs,
                          sampling_params.prompt_logprobs):
            if requested and requested > max_logprobs:
                raise ValueError(f"Cannot request more than "
                                 f"{max_logprobs} logprobs.")

        # Attach logits processors first, then take a defensive copy for the
        # sampler. The clone does not deep-copy LogitsProcessor objects.
        params = self._build_logits_processors(sampling_params,
                                               lora_request).clone()
        params.update_from_generation_config(self.generation_config_fields,
                                             seq.eos_token_id)

        # One slot by default; with speculative decoding configured, one more
        # than the number of speculative tokens.
        spec_config = self.vllm_config.speculative_config
        if spec_config is None:
            draft_size = 1
        else:
            draft_size = spec_config.num_speculative_tokens + 1

        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            encoder_seq=encoder_seq,
            priority=priority,
            draft_size=draft_size,
        )

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a pooling request.

        Args:
            request_id: ID stored on the resulting group.
            seq: The one sequence the group will contain.
            pooling_params: Caller-provided pooling parameters; a clone (not
                the caller's object) is stored on the group.
            arrival_time: Arrival timestamp forwarded to the group.
            lora_request: Optional LoRA request forwarded to the group.
            encoder_seq: Optional encoder sequence forwarded to the group.
            priority: Request priority forwarded to the group.

        Returns:
            The newly constructed SequenceGroup.
        """
        # The pooler gets its own copy so later caller-side mutation of
        # `pooling_params` cannot affect this request.
        params_copy = pooling_params.clone()
        return SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=params_copy,
            encoder_seq=encoder_seq,
            priority=priority,
        )
805

Antoni Baum's avatar
Antoni Baum committed
806
807
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests on every scheduler.

        Args:
            request_id: The ID(s) of the request(s) to abort.

        Details:
            - Each scheduler's `abort_seq_group` is invoked with the ID(s)
              and the engine's `seq_id_to_seq_group` mapping. Refer to
              [vllm.core.scheduler.Scheduler.abort_seq_group][].

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        groups_by_seq_id = self.seq_id_to_seq_group
        for sched in self.scheduler:
            sched.abort_seq_group(request_id,
                                  seq_id_to_seq_group=groups_by_seq_id)
824

825
826
827
828
    def get_vllm_config(self) -> VllmConfig:
        """Return the engine's top-level `VllmConfig` object."""
        vllm_config = self.vllm_config
        return vllm_config

829
830
831
832
    def get_model_config(self) -> ModelConfig:
        """Return the `ModelConfig` this engine was set up with."""
        model_config = self.model_config
        return model_config

833
834
835
836
    def get_parallel_config(self) -> ParallelConfig:
        """Return the `ParallelConfig` this engine was set up with."""
        parallel_config = self.parallel_config
        return parallel_config

837
838
839
840
    def get_decoding_config(self) -> DecodingConfig:
        """Return the `DecodingConfig` this engine was set up with."""
        decoding_config = self.decoding_config
        return decoding_config

841
842
843
844
845
846
847
848
    def get_scheduler_config(self) -> SchedulerConfig:
        """Return the `SchedulerConfig` this engine was set up with."""
        scheduler_config = self.scheduler_config
        return scheduler_config

    def get_lora_config(self) -> LoRAConfig:
        """Return the engine's LoRA configuration.

        NOTE(review): `add_request` treats a falsy `self.lora_config` as
        "LoRA is not enabled", so callers should presumably be prepared for
        a falsy value here -- confirm against the constructor.
        """
        lora_config = self.lora_config
        return lora_config

849
    def get_num_unfinished_requests(self) -> int:
        """Return the total count of unfinished sequence groups.

        The count is accumulated over every scheduler in `self.scheduler`.
        """
        total = 0
        for sched in self.scheduler:
            total += sched.get_num_unfinished_seq_groups()
        return total
853

854
    def has_unfinished_requests(self) -> bool:
        """Return True if any scheduler still has unfinished sequences.

        Schedulers are queried in order and the search stops at the first
        one that reports unfinished sequences.
        """
        for sched in self.scheduler:
            if sched.has_unfinished_seqs():
                return True
        return False

    def has_unfinished_requests_for_virtual_engine(
            self, virtual_engine: int) -> bool:
        """Return True if the given virtual engine has unfinished requests.

        Args:
            virtual_engine: Index into `self.scheduler` selecting the
                scheduler to query.
        """
        engine_scheduler = self.scheduler[virtual_engine]
        return engine_scheduler.has_unfinished_seqs()
865

866
867
    def reset_mm_cache(self) -> bool:
        """Reset the multi-modal processor cache.

        Delegates to the input preprocessor's multi-modal registry, passing
        the engine's model config, and returns that call's result.
        """
        registry = self.input_preprocessor.mm_registry
        return registry.reset_processor_cache(self.model_config)
870

871
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the prefix cache on each scheduler.

        Args:
            device: Forwarded unchanged to each scheduler's
                `reset_prefix_cache`.

        Returns:
            The result of the last reset that was attempted (True when there
            are no schedulers). Once a scheduler reports a falsy result, the
            remaining schedulers are not reset.
        """
        outcome = True
        for sched in self.scheduler:
            if not outcome:
                # A previous scheduler failed; do not touch the rest.
                break
            outcome = sched.reset_prefix_cache(device)
        return outcome

879
    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
        outputs: List[PoolingSequenceGroupOutput],
    ) -> None:
        """Store pooled data on a sequence group and finish its sequences.

        Args:
            seq_group: The group to update in place.
            outputs: Pooling outputs for the group; only the first element's
                `data` is used.
        """
        first_output = outputs[0]
        seq_group.pooled_data = first_output.data

        # Every sequence in the group is marked as stopped.
        for pooled_seq in seq_group.get_seqs():
            pooled_seq.status = SequenceStatus.FINISHED_STOPPED

891
892
893
894
895
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply queued model output to the scheduled sequence groups.

        Takes the oldest entry of `ctx.output_queue`, updates each affected
        sequence group, appends the resulting request outputs to
        `ctx.request_outputs`, and hands them to
        `self.process_request_outputs_callback` when one is set. Nothing is
        returned; results are delivered through `ctx.request_outputs` and
        the callback.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, only this request is processed. The
                queue entry is left in place and the request's index is
                added to the entry's `skip` list so that a later full pass
                does not process it again.
        """

        # Single timestamp reused for all first/last-token time updates below.
        now = time.time()

        if len(ctx.output_queue) == 0:
            return None

        # Get pending async postprocessor
        # NOTE: `is_last_step` and `is_first_step_output` are unpacked but not
        # used anywhere in this method.
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)

        # Only a single output set is accepted here (asserted below), so the
        # `has_multiple_outputs` branch further down is not taken while
        # assertions are enabled.
        has_multiple_outputs: bool = len(outputs) > 1
        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
        assert not has_multiple_outputs
        outputs_by_sequence_group = outputs

        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break

            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore

        # Index bookkeeping for this pass:
        #   finished_before   - groups already finished on entry
        #   finished_now      - groups that finish during this pass
        #   empty_seq_indices - groups whose samples were all invalid tokens
        finished_before: List[int] = []
        finished_now: List[int] = []
        empty_seq_indices: List[int] = []
        for i in indices:
            if i in skip:
                continue

            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group: SequenceGroup = scheduled_seq_group.seq_group

            if seq_group.is_finished():
                finished_before.append(i)
                continue

            output: List[SequenceGroupOutput]
            if has_multiple_outputs:
                output = outputs_by_sequence_group[i]
            else:
                output = [outputs_by_sequence_group[0][i]]

            # tree style speculative decoding may generate empty output in first step
            # NOTE(review): this `continue` runs before the computed-token
            # update and the metrics accumulation below, so neither happens
            # for such a group on this pass -- confirm that is intended.
            if self.tree_decoding and outputs and isinstance(output[0], CompletionSequenceGroupOutput):
                samples = [o.samples[0] for o in output]
                valid_samples = [
                    sample for sample in samples
                    if sample.output_token != VLLM_INVALID_TOKEN_ID
                ]
                if len(valid_samples) == 0:
                    empty_seq_indices.append(i)
                    continue

            # In the async case the computed-token count is not updated here;
            # `_advance_to_next_step` performs that update instead.
            if not is_async:
                seq_group.update_num_computed_tokens(
                    seq_group_meta.token_chunk_size or 0)

            # Accumulate model forward/execute timings onto the group's
            # metrics, initialising each field on first use.
            if outputs:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time or 0)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time or 0)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            # Pooling runners only store pooled data; generation goes through
            # the output processor (prompt logprobs first, then samples).
            if self.model_config.runner_type == "pooling":
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)

            if seq_group.is_finished():
                finished_now.append(i)

        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            skip.append(indices[0])

            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return

        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()

        # Create the outputs
        for i in indices:
            if i in skip or i in finished_before or i in finished_now or i in empty_seq_indices:
                continue  # Avoids double processing

            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            if not seq_group.is_prefill():
                seq_group.set_last_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

        # Create outputs only after processing the scheduler's results

        for seq_group in scheduler_outputs.ignored_seq_groups:
            # In DELTA output mode an unfinished ignored group produces no
            # output.
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue

            request_output = RequestOutputFactory.create(
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()

        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)

            # Tracing
            self.do_tracing(scheduler_outputs, finished_before)

        return None

    def _advance_to_next_step(
            self, output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Append the sampled tokens of one model run to their sequences.

        The output processor normally does this, but when output processing
        is asynchronous the sequences must be advanced here so the worker
        can start the next forward pass.

        Args:
            output: Sampler output of a single run, iterated in lockstep
                with the two lists below.
            seq_group_metadata_list: Metadata for each scheduled group.
            scheduled_seq_groups: The scheduled groups to advance.
        """
        triples = zip(seq_group_metadata_list, output, scheduled_seq_groups)
        for metadata, group_output, scheduled in triples:
            seq_group = scheduled.seq_group

            # Finished groups are left untouched.
            if seq_group.is_finished():
                continue

            # A missing chunk size counts as zero computed tokens.
            chunk_size = metadata.token_chunk_size
            if chunk_size is None:
                chunk_size = 0
            seq_group.update_num_computed_tokens(chunk_size)

            if not metadata.do_sample:
                continue

            assert len(group_output.samples) == 1, (
                "Async output processor expects a single sample"
                " (i.e sampling_params.n == 1)")
            sample = group_output.samples[0]

            assert len(seq_group.seqs) == 1
            only_seq = seq_group.seqs[0]
            only_seq.append_token_id(sample.output_token, sample.logprobs,
                                     sample.output_embed)
1124

1125
    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        <figure markdown="span">
        ![Overview of the step function](https://i.imgur.com/sv2HssD.png)
        <figcaption>Overview of the step function</figcaption>
        </figure>

        Details:
        - Step 1: Schedules the sequences to be executed in the next
            iteration and the token blocks to be swapped in/out/copy.

            - Depending on the scheduling policy,
                sequences may be `preempted/reordered`.
            - A Sequence Group (SG) refers to a group of sequences
                that are generated from the same prompt.

        - Step 2: Calls the distributed executor to execute the model.
        - Step 3: Processes the model output. This mainly includes:

            - Decodes the relevant outputs.
            - Updates the scheduled sequence groups with model outputs
                based on its `sampling parameters` (`use_beam_search` or not).
            - Frees the finished sequence groups.

        - Finally, it creates and returns the newly generated results.

        Returns:
            `ctx.request_outputs` of virtual engine 0, i.e. the request
            outputs produced during this call.

        Raises:
            NotImplementedError: If `pipeline_parallel_size > 1`.
            InputProcessingError: Re-raised from the executor after the
                offending request has been aborted and the remaining
                schedule cached for the next step.

        Example:
        ```
        # Please see the example/ folder for more detailed examples.

        # initialize engine and request arguments
        engine = LLMEngine.from_engine_args(engine_args)
        example_inputs = [(0, "What is LLM?",
        SamplingParams(temperature=0.0))]

        # Start the engine with an event loop
        while True:
            if example_inputs:
                req_id, prompt, sampling_params = example_inputs.pop(0)
                engine.add_request(str(req_id),prompt,sampling_params)

            # continue the request processing
            request_outputs = engine.step()
            for request_output in request_outputs:
                if request_output.finished:
                    # return or show the request output
                    ...

            if not (engine.has_unfinished_requests() or example_inputs):
                break
        ```
        """
        if self.parallel_config.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

        # For llm_engine, there is no pipeline parallel support, so the engine
        # used is always 0.
        virtual_engine = 0

        # These are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        # The scheduler is also skipped if a single request caused the last
        # engine step to fail, and the previous schedule needs to be rerun.
        if not self._has_remaining_steps(
                seq_group_metadata_list
        ) and not self._skip_scheduling_next_step:
            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            # Publish the fresh schedule on the context.
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
            # When n>1, elements in self.seq_id_to_seq_group should be deleted
            # here, otherwise memory leaks.
            for finished_request_id in finished_requests_ids:
                if finished_request_id in self.seq_id_to_seq_group:
                    del self.seq_id_to_seq_group[finished_request_id]

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

        else:
            # Reusing the cached schedule: no newly finished request IDs are
            # reported to the executor.
            finished_requests_ids = list()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            try:
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
                # A successful run clears any pending "rerun the cached
                # schedule" request from an earlier failed step.
                self._skip_scheduling_next_step = False
            except InputProcessingError as e:
                # The input for this request cannot be processed, so we must
                # abort it. If there are remaining requests in the batch that
                # have been scheduled, they will be retried on the next step.
                invalid_request_id = e.request_id
                self._abort_and_cache_schedule(
                    request_id=invalid_request_id,
                    virtual_engine=virtual_engine,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    allow_async_output_proc=allow_async_output_proc)
                # Raise so the caller is notified that this request failed
                raise

        else:
            # Nothing scheduled => If there is pending async postprocessor,
            # then finish it here.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            outputs = []

        if not self._has_remaining_steps(seq_group_metadata_list):
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1.
            # NOTE(review): only the first metadata entry is inspected here;
            # presumably all entries share the same num_steps -- confirm.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            # Add results to the output_queue
            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            # In async mode the queued output is post-processed later, so the
            # sequences are advanced right away.
            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")

                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Check if need to run the usual non-async path
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs
Antoni Baum's avatar
Antoni Baum committed
1333

1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
    def _abort_and_cache_schedule(
            self, request_id: str, virtual_engine: int,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        """Abort one request and cache the rest of the current schedule.

        The request is aborted and its entries are removed in place from
        `seq_group_metadata_list` and
        `scheduler_outputs.scheduled_seq_groups`. If anything is left, the
        trimmed schedule is cached and the engine is flagged to reuse it on
        the next step instead of re-running the scheduler.

        Args:
            request_id: ID of the request to abort.
            virtual_engine: Index of the cache slot to store the schedule in.
            seq_group_metadata_list: Current schedule metadata (mutated).
            scheduler_outputs: Current scheduler outputs (mutated).
            allow_async_output_proc: Stored alongside the cached schedule.
        """
        self.abort_request(request_id)

        # Drop the first metadata entry that belongs to the aborted request.
        metadata_pos = next(
            (pos for pos, metadata in enumerate(seq_group_metadata_list)
             if metadata.request_id == request_id), None)
        if metadata_pos is not None:
            del seq_group_metadata_list[metadata_pos]

        # Likewise for the first matching scheduled sequence group.
        scheduled_groups = scheduler_outputs.scheduled_seq_groups
        group_pos = next(
            (pos for pos, group in enumerate(scheduled_groups)
             if group.seq_group.request_id == request_id), None)
        if group_pos is not None:
            del scheduled_groups[group_pos]

        # Nothing else was scheduled: there is no schedule worth caching.
        if not seq_group_metadata_list:
            return

        self._skip_scheduling_next_step = True
        # Reuse multi-step caching logic
        self._cache_scheduler_outputs_for_multi_step(
            virtual_engine=virtual_engine,
            scheduler_outputs=scheduler_outputs,
            seq_group_metadata_list=seq_group_metadata_list,
            allow_async_output_proc=allow_async_output_proc)

1366
1367
1368
    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
1369
        return False
1370
1371
1372
1373

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
1374
1375
1376
1377
1378
1379
1380
1381
            scheduler_outputs: SchedulerOutputs,
            allow_async_output_proc: bool) -> None:
        co = self.cached_scheduler_outputs[virtual_engine]

        co.seq_group_metadata_list = seq_group_metadata_list
        co.scheduler_outputs = scheduler_outputs
        co.allow_async_output_proc = allow_async_output_proc
        co.last_output = None
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                and output[0] is not None):
            last_output = output[-1]
            assert last_output is not None
            assert last_output.sampled_token_ids_cpu is not None
            assert last_output.sampled_token_ids is None
            assert last_output.sampled_token_probs is None
            self.cached_scheduler_outputs[
                virtual_engine].last_output = last_output

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        return None
Antoni Baum's avatar
Antoni Baum committed
1399

1400
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under `logger_name`.

        Args:
            logger_name: Key under which the logger is stored.
            logger: The stat logger to register.

        Raises:
            RuntimeError: If stat logging is disabled on this engine.
            KeyError: If a logger with the same name is already registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} already exists.")
        self.stat_loggers[logger_name] = logger

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger stored under `logger_name`.

        Args:
            logger_name: Key of the logger to remove.

        Raises:
            RuntimeError: If stat logging is disabled on this engine.
            KeyError: If no logger with that name is registered.
        """
        if not self.log_stats:
            raise RuntimeError(
                "Stat logging is disabled. Set `disable_log_stats=False` "
                "argument to enable.")
        if logger_name not in self.stat_loggers:
            raise KeyError(f"Logger with name {logger_name} does not exist.")
        del self.stat_loggers[logger_name]

1418
1419
1420
    def do_log_stats(self,
                     scheduler_outputs: Optional[SchedulerOutputs] = None,
                     model_output: Optional[List[SamplerOutput]] = None,
                     finished_before: Optional[List[int]] = None,
                     skip: Optional[List[int]] = None) -> None:
        """Forced log when no requests active.

        Collects a stats snapshot via `_get_stats` and hands it to every
        registered stat logger. Does nothing when stat logging is disabled.

        Args:
            scheduler_outputs: Forwarded to `_get_stats`.
            model_output: Forwarded to `_get_stats`.
            finished_before: Forwarded to `_get_stats`.
            skip: Forwarded to `_get_stats`.
        """
        if self.log_stats:
            stats = self._get_stats(scheduler_outputs, model_output,
                                    finished_before, skip)
            for logger in self.stat_loggers.values():
                logger.log(stats)
1429

1430
1431
1432
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs],
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None,
                   skip: Optional[List[int]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers. NOTE(review): this
                implementation does not read it.
            finished_before: Optional, indices of sequences that were finished
                before. These sequences will be ignored.
            skip: Optional, indices of sequences that were preempted. These
                sequences will be ignored.

        Returns:
            A `Stats` snapshot combining system-level, iteration-level and
            finished-request-level measurements.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = sum(
            len(scheduler.running) for scheduler in self.scheduler)
        num_swapped_sys = sum(
            len(scheduler.swapped) for scheduler in self.scheduler)
        num_waiting_sys = sum(
            len(scheduler.waiting) for scheduler in self.scheduler)

        # KV Cache Usage in %
        # (computed as 1 - free / total over all schedulers)
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu:  # Guard against both None and 0
            num_free_gpu = sum(
                scheduler.block_manager.get_num_free_gpu_blocks()
                for scheduler in self.scheduler)
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu:  # Guard against both None and 0
            num_free_cpu = sum(
                scheduler.block_manager.get_num_free_cpu_blocks()
                for scheduler in self.scheduler)
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Prefix Cache Hit Rate. Note that we always use
        # the cache hit rate of the first virtual engine.
        cpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.CPU)
        gpu_prefix_cache_hit_rate = self.scheduler[
            0].get_prefix_cache_hit_rate(Device.GPU)

        # Exchange the usage and cache hit stats between gpu and cpu when
        # running on cpu because the cpu_worker.py intentionally reports the
        # number of cpu blocks as gpu blocks in favor of cache management.
        if self.device_config.device_type == "cpu":
            num_total_gpu, num_total_cpu = num_total_cpu, num_total_gpu
            gpu_cache_usage_sys, cpu_cache_usage_sys = (
                cpu_cache_usage_sys,
                gpu_cache_usage_sys,
            )
            gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = (
                cpu_prefix_cache_hit_rate,
                gpu_prefix_cache_hit_rate,
            )

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        num_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        time_queue_requests: List[float] = []
        time_inference_requests: List[float] = []
        time_prefill_requests: List[float] = []
        time_decode_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        n_requests: List[int] = []
        max_num_generation_tokens_requests: List[int] = []
        max_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # LoRA requests
        running_lora_adapters = dict(
            collectionsCounter([
                running_request.lora_request.lora_name
                for scheduler in self.scheduler
                for running_request in scheduler.running
                if running_request.lora_request
            ]))
        waiting_lora_adapters = dict(
            collectionsCounter([
                waiting_request.lora_request.lora_name
                for scheduler in self.scheduler
                for waiting_request in scheduler.waiting
                if waiting_request.lora_request
            ]))
        max_lora_stat = "0"
        if self.lora_config:
            max_lora_stat = str(self.lora_config.max_loras)

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # For async postprocessor, already finished sequences need to be
            # not counted (to avoid double counting)
            actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore

            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            # Build the index sets once so the per-group membership checks
            # below do not rescan the lists on every iteration.
            finished_before_indices = frozenset(finished_before or ())
            skipped_indices = frozenset(skip or ())

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                # Skip double logging when using async output proc
                if idx in finished_before_indices:
                    actual_num_batched_tokens -= 1
                    continue

                # Currently, skip == preempted sequences, so we need to skip
                # their log stats
                if idx in skipped_indices:
                    continue

                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_token_latency()
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_token_latency()
                    time_per_output_tokens_iter.append(latency)
                    if seq_group.state.current_step == 0:
                        # For async_output_proc, the do_log_stats()
                        # is called following init_multi_step(), which
                        # sets the current_step to zero.
                        actual_num_batched_tokens +=\
                            seq_group.state.num_steps - 1
                    else:
                        actual_num_batched_tokens +=\
                            seq_group.state.current_step - 1

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
                    if (seq_group.metrics.first_scheduled_time is not None and
                            seq_group.metrics.first_token_time is not None):
                        time_queue_requests.append(
                            seq_group.metrics.first_scheduled_time -
                            seq_group.metrics.arrival_time)
                        time_prefill_requests.append(
                            seq_group.metrics.first_token_time -
                            seq_group.metrics.first_scheduled_time)
                        time_decode_requests.append(
                            now - seq_group.metrics.first_token_time)
                        time_inference_requests.append(
                            now - seq_group.metrics.first_scheduled_time)
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    max_num_generation_tokens_requests.append(
                        max(seq.get_output_len()
                            for seq in seq_group.get_seqs()))
                    if seq_group.sampling_params is not None:
                        n_requests.append(seq_group.sampling_params.n)
                        max_tokens_requests.append(
                            seq_group.sampling_params.max_tokens)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                actual_num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
            num_tokens_iter = (num_generation_tokens_iter +
                               num_prompt_tokens_iter)

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,
            #   Prefix Cache Hit Rate
            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            num_tokens_iter=num_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            time_queue_requests=time_queue_requests,
            time_inference_requests=time_inference_requests,
            time_prefill_requests=time_prefill_requests,
            time_decode_requests=time_decode_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            max_num_generation_tokens_requests=
            max_num_generation_tokens_requests,
            n_requests=n_requests,
            max_tokens_requests=max_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            max_lora=str(max_lora_stat),
            waiting_lora_adapters=list(waiting_lora_adapters.keys()),
            running_lora_adapters=list(running_lora_adapters.keys()))
1690

1691
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the model executor.

        Returns:
            Whatever the executor's `add_lora` returns.
        """
        return self.model_executor.add_lora(lora_request)
1693
1694

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal request to the model executor.

        Returns:
            Whatever the executor's `remove_lora` returns.
        """
        return self.model_executor.remove_lora(lora_id)
1696

1697
    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model executor."""
        return self.model_executor.list_loras()
1699

1700
1701
1702
    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model executor to pin the LoRA identified by `lora_id`.

        Returns:
            Whatever the executor's `pin_lora` returns.
        """
        executor = self.model_executor
        return executor.pin_lora(lora_id)

1703
1704
1705
1706
1707
1708
    def start_profile(self) -> None:
        """Tell the model executor to start profiling."""
        executor = self.model_executor
        executor.start_profile()

    def stop_profile(self) -> None:
        """Tell the model executor to stop profiling."""
        executor = self.model_executor
        executor.stop_profile()

1709
1710
1711
1712
1713
    def sleep(self, level: int = 1) -> None:
        """Put the model executor to sleep at the given level.

        Args:
            level: Sleep level passed through to the executor.

        Raises:
            AssertionError: If sleep mode is not enabled in the model config.
        """
        sleep_mode_enabled = self.vllm_config.model_config.enable_sleep_mode
        assert sleep_mode_enabled, (
            "Sleep mode is not enabled in the model config")
        executor = self.model_executor
        executor.sleep(level=level)

1714
    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the model executor up from sleep mode.

        Args:
            tags: Optional tags passed through to the executor's `wake_up`.

        Raises:
            AssertionError: If sleep mode is not enabled in the model config.
        """
        assert self.vllm_config.model_config.enable_sleep_mode, (
            "Sleep mode is not enabled in the model config")
        self.model_executor.wake_up(tags)
1718

1719
1720
1721
    def is_sleeping(self) -> bool:
        """Return the model executor's `is_sleeping` attribute."""
        executor = self.model_executor
        return executor.is_sleeping

1722
    def check_health(self) -> None:
        """Run the model executor's health check."""
        self.model_executor.check_health()
1724
1725
1726
1727

    def is_tracing_enabled(self) -> bool:
        """Return True when a tracer is configured on this engine."""
        tracer = self.tracer
        return tracer is not None

1728
1729
1730
    def do_tracing(self,
                   scheduler_outputs: SchedulerOutputs,
                   finished_before: Optional[List[int]] = None) -> None:
        """Create a trace span for every finished scheduled sequence group.

        Does nothing when no tracer is configured.

        Args:
            scheduler_outputs: Scheduler outputs whose `scheduled_seq_groups`
                are inspected.
            finished_before: Optional indices (into `scheduled_seq_groups`)
                of groups that were already finished earlier; these are
                skipped.
        """
        if self.tracer is None:
            return

        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            # Skip double tracing when using async output proc
            if finished_before and idx in finished_before:
                continue

            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                self.create_trace_span(seq_group)

    def create_trace_span(self, seq_group: SequenceGroup) -> None:
        """Emit an "llm_request" trace span describing `seq_group`.

        Returns without doing anything when no tracer is configured or the
        group has no sampling params. The span starts at the request's
        arrival time and carries the request and sampling parameters, token
        usage counts, and whichever latency metrics are available.

        Args:
            seq_group: The sequence group to describe.
        """
        if self.tracer is None or seq_group.sampling_params is None:
            return
        # The span start time is passed in nanoseconds.
        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)

        trace_context = extract_trace_context(seq_group.trace_headers)

        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as seq_span:
            metrics = seq_group.metrics

            # Handle potential None values for cancelled/aborted requests
            ttft = (metrics.first_token_time - metrics.arrival_time
                    if metrics.first_token_time is not None else None)

            e2e_time = (metrics.finished_time - metrics.arrival_time
                        if metrics.finished_time is not None else None)

            seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                                   self.model_config.model)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                                   seq_group.request_id)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   seq_group.sampling_params.temperature)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   seq_group.sampling_params.top_p)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   seq_group.sampling_params.max_tokens)
            seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   seq_group.sampling_params.n)
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
                                   seq_group.num_seqs())
            seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                                   len(seq_group.prompt_token_ids))
            # Completion tokens are counted over finished sequences only.
            seq_span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                sum(seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()))

            # Only set timing attributes if the values are available
            if metrics.time_in_queue is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                    metrics.time_in_queue)
            if ttft is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            if e2e_time is not None:
                seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E,
                                       e2e_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                # NOTE(review): the / 1000.0 presumably converts
                # milliseconds to seconds; confirm against where
                # model_forward_time is recorded.
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)
1811

1812
    def _validate_model_inputs(self, inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest]):
        """Validate the encoder (if present) and decoder parts of `inputs`.

        The inputs are split with `split_enc_dec_inputs`; each part is then
        checked by `_validate_model_input`, which raises on invalid prompts.

        Args:
            inputs: Processed inputs to validate.
            lora_request: Optional LoRA request, forwarded to the per-prompt
                validation.
        """
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")
1824

1825
1826
1827
1828
1829
1830
1831
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
1832
        model_config = self.model_config
zhuwenwen's avatar
zhuwenwen committed
1833
1834
1835
1836
1837
1838
1839
1840
        if self.tokenizer is None:
            tokenizer = None
        elif self.model_config.tokenizer_mode != "cpm":
            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        else:
            tokenizer = self.tokenizer
        # tokenizer = (None if self.tokenizer is None else
        #              self.tokenizer.get_lora_tokenizer(lora_request))
1841

1842
        prompt_ids = prompt_inputs.get("prompt_token_ids", [])
1843
1844
1845
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
1846
            elif prompt_inputs["type"] == "embeds":
1847
                pass
1848
1849
1850
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

1851
1852
1853
1854
1855
1856
        if tokenizer is not None:
            max_input_id = max(prompt_ids, default=0)
            if max_input_id > tokenizer.max_token_id:
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")

1857
        max_prompt_len = self.model_config.max_model_len
1858
        if len(prompt_ids) > max_prompt_len:
1859
            if prompt_type == "encoder" and model_config.is_multimodal_model:
1860
1861
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
1862
1863
1864
                    model_config,
                    tokenizer=tokenizer or object(),  # Dummy if no tokenizer
                )
1865
                assert isinstance(mm_processor, EncDecMultiModalProcessor)
1866

1867
1868
1869
                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

1870
            if model_config.is_multimodal_model:
1871
                suggestion = (
1872
1873
1874
1875
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
1876
1877
1878
1879
1880
1881
1882
1883
1884
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")
1885
1886
1887
1888

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens
1889
1890
1891
1892

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
1893
1894
1895
1896
        """Constructs logits processors based on the logits_bias, and
        allowed_token_ids fields in sampling_params. Deletes those fields and
        adds the constructed logits processors to the logits_processors field.
        Returns the modified sampling params."""
1897
1898

        logits_processors = []
1899

1900
1901
1902
        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
            tokenizer = self.get_tokenizer(lora_request=lora_request)

1903
            processors = get_openai_logits_processors(
1904
1905
1906
1907
1908
1909
1910
1911
1912
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

1913
1914
1915
1916
1917
1918
        if len(sampling_params.bad_words) > 0:
            tokenizer = self.get_tokenizer(lora_request)
            processors = get_bad_words_logits_processors(
                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
            logits_processors.extend(processors)

1919
1920
1921
1922
1923
1924
1925
        if logits_processors:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = logits_processors
            else:
                sampling_params.logits_processors.extend(logits_processors)

        return sampling_params
1926

1927
1928
1929
1930
1931
1932
1933
1934
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC call to the model executor.

        Args:
            method: Method name or callable, passed through unchanged.
            timeout: Optional timeout, passed through unchanged.
            args: Positional arguments for the call.
            kwargs: Optional keyword arguments for the call.

        Returns:
            Whatever the executor's `collective_rpc` returns.
        """
        executor = self.model_executor
        results = executor.collective_rpc(method, timeout, args, kwargs)
        return results

1935

1936
1937
1938
# When `envs` reports VLLM_USE_V1 as set and its value is truthy, rebind the
# module-level name `LLMEngine` to the V1 engine class, so that code importing
# `LLMEngine` from this module gets the V1 implementation instead. The import
# is done lazily here so the V1 module is only loaded when actually selected.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    LLMEngine = V1LLMEngine  # type: ignore