llm.py 87.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
chenzk's avatar
chenzk committed
5
import os
6
from collections.abc import Callable, Iterable, Sequence
7
from pathlib import Path
8
from typing import TYPE_CHECKING, Any
9

10
import cloudpickle
11
import torch.nn as nn
12
from pydantic import ValidationError
13
from tqdm.auto import tqdm
14
from typing_extensions import TypeVar, overload
15

16
17
18
19
20
21
22
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
23
    AttentionConfig,
24
    CompilationConfig,
25
    PoolerConfig,
26
    ProfilerConfig,
27
28
29
    StructuredOutputsConfig,
    is_init_field,
)
30
from vllm.config.compilation import CompilationMode
31
from vllm.config.model import (
32
33
    ConvertOption,
    HfOverrides,
34
    ModelDType,
35
    RunnerOption,
36
    TokenizerMode,
37
)
38
39
40
41
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
42
from vllm.engine.arg_utils import EngineArgs
43
44
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
45
    ChatTemplateConfig,
46
    ChatTemplateContentFormatOption,
47
    load_chat_template,
48
)
49
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
50
from vllm.entrypoints.pooling.score.utils import (
51
    ScoreData,
52
53
54
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
55
    compute_maxsim_score,
56
    get_score_prompt,
57
    score_data_to_prompts,
58
    validate_score_input,
59
)
60
from vllm.entrypoints.utils import log_non_default_args
61
from vllm.inputs.data import (
62
    DataPrompt,
63
    ProcessorInputs,
64
65
66
67
68
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
69
from vllm.logger import init_logger
70
from vllm.lora.request import LoRARequest
71
from vllm.model_executor.layers.quantization import QuantizationMethods
72
73
74
75
76
77
78
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
79
from vllm.platforms import current_platform
80
from vllm.pooling_params import PoolingParams
81
from vllm.renderers import ChatParams, merge_kwargs
82
83
84
85
86
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
87
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
88
from vllm.tasks import PoolingTask
89
from vllm.tokenizers import TokenizerLike
yhu422's avatar
yhu422 committed
90
from vllm.usage.usage_lib import UsageContext
91
from vllm.utils.counter import Counter
92
from vllm.utils.mistral import is_mistral_tokenizer
93
from vllm.utils.tqdm_utils import maybe_tqdm
94
from vllm.v1.engine import PauseMode
95
from vllm.v1.engine.llm_engine import LLMEngine
96
from vllm.v1.sample.logits_processor import LogitsProcessor
97

98
if TYPE_CHECKING:
chenzk's avatar
chenzk committed
99
    from vllm.kvprune.integration.compression_params import CompressionParams
100
101
    from vllm.v1.metrics.reader import Metric

102
103
logger = init_logger(__name__)

# Every output type the engine loop can produce; serves as both the upper
# bound and the default of ``_O``.
_AnyOutput = RequestOutput | PoolingRequestOutput

# Engine output type collected by the run loop (generation or pooling).
_O = TypeVar("_O", bound=_AnyOutput, default=_AnyOutput)
# Per-request parameter type: sampling params, pooling params, or nothing.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
# Generic result type (used e.g. by the ``_make_config`` helper in
# ``LLM.__init__``); defaults to ``Any``.
_R = TypeVar("_R", default=Any)

112
113

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
114
115
116
117
118
119
120
121
122
123
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
124
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
125
126
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
127
128
129
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
130
131
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
132
133
134
135
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
141
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
142
143
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
144
        quantization: The method used to quantize the model weights. Currently,
145
            we support "awq", "gptq", and "fp8" (experimental).
146
147
148
149
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
150
151
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
152
153
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
154
        chat_template: The chat template to apply.
155
156
157
158
159
160
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
161
162
163
164
165
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vllm can automatically infer the kv cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the kv cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used when
            compared with using gpu_memory_utilization. Note that when
            kv_cache_memory_bytes is not None, gpu_memory_utilization is
            ignored.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
173
174
175
176
177
178
179
180
181
182
183
184
185
        offload_group_size: Prefetch offloading: Group every N layers
            together. Offload last `offload_num_in_group` layers of each group.
            Default is 0 (disabled).
        offload_num_in_group: Prefetch offloading: Number of layers to
            offload per group. Default is 1.
        offload_prefetch_step: Prefetch offloading: Number of layers to
            prefetch ahead. Higher values hide more latency but use more GPU
            memory. Default is 1.
        offload_params: Prefetch offloading: Set of parameter name segments
            to selectively offload. Only parameters whose names contain one of
            these segments will be offloaded (e.g., {"gate_up_proj", "down_proj"}
            for MLP weights, or {"w13_weight", "w2_weight"} for MoE expert
            weights). If None or empty, all parameters are offloaded.
186
187
188
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
chenzk's avatar
chenzk committed
189
190
191
192
193
194
195
196
197
        kvprune_compression: If True, sets ``enforce_eager=True`` for the **v1**
            engine only (no v1 CUDA graph capture). If ``None`` (default), the
            value is read from ``VLLM_KVPRUNE_COMPRESSION_DEFAULT`` (default
            ``"0"``): ``"1"``, ``"true"`` or ``"yes"`` (case-insensitive)
            enable it and skip v1 graphs; any other value allows v1 graphs.
            This is independent of the compactor's ``LLMConfig.enforce_eager``
            (see ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` /
            ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``; by default compactor
            graphs are attempted). When enabled, v1's GPU KV pool defaults to
            **one** block (the minimum allowed by the scheduler) unless
            ``num_gpu_blocks_override`` is passed in ``**kwargs`` or
            ``VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS`` is set (``auto`` or ``profile``
            = profiled allocation; an integer ``n`` = ``max(1, n)`` blocks).
198
        enable_return_routed_experts: Whether to return routed experts.
199
200
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
201
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
204
205
206
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
207
208
209
210
211
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
212
213
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
214
        compilation_config: Either an integer or a dictionary. If it is an
215
            integer, it is used as the mode of compilation optimization. If it
216
            is a dictionary, it can specify the full compilation configuration.
217
218
219
220
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
221
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
222

223
224
    Note:
        This class is intended to be used for offline inference. For online
225
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
226
    """
227
228
229
230

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        chat_template: Path | str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        cpu_offload_gb: float = 0,
        offload_group_size: int = 0,
        offload_num_in_group: int = 1,
        offload_prefetch_step: int = 1,
        offload_params: set[str] | None = None,
        enforce_eager: bool = False,
        kvprune_compression: bool | None = None,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """Build the engine arguments, start the engine and set up helpers.

        Parameter semantics are documented on the class docstring; any extra
        keyword arguments are forwarded to ``EngineArgs``.

        Raises:
            ValueError: If ``kv_transfer_config`` is a dict that fails
                ``KVTransferConfig`` validation; if ``data_parallel_size > 1``
                is requested without the ``external_launcher`` backend on a
                non-TPU platform; or if ``VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS`` is
                set to something other than ``auto``, ``profile`` or an
                integer while kvprune compression is enabled.
        """
        # ``swap_space`` is still accepted for backward compatibility only.
        if "swap_space" in kwargs:
            kwargs.pop("swap_space")
            import warnings

            warnings.warn(
                "The 'swap_space' parameter is deprecated and ignored. "
                "It will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Offline usage turns stats logging off unless the caller opts in.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if "kv_transfer_config" in kwargs and isinstance(
            kwargs["kv_transfer_config"], dict
        ):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Surface the pydantic failure as a ValueError for callers.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            Dict keys that are not init fields of ``cls`` are dropped.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        # An int compilation config is shorthand for the compilation mode.
        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data parallelism; only the external launcher
        # backend and TPU platforms are exempt from this check.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        # v1 ``enforce_eager`` is independent of kvprune compactor ``LLMConfig.enforce_eager``.
        if kvprune_compression is None:
            _kvd = (
                os.environ.get("VLLM_KVPRUNE_COMPRESSION_DEFAULT", "0")
                .strip()
                .lower()
            )
            kvprune_compression = _kvd in ("1", "true", "yes")
        if kvprune_compression:
            enforce_eager = True
            # Reserve minimal v1 GPU KV so compactor can use the rest of VRAM. v1
            # scheduler requires num_gpu_blocks >= 1; profiling would allocate a
            # large pool from gpu_memory_utilization. Override:
            #   VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS unset         -> 1 block (default)
            #   VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=auto|profile  -> profiled (no override)
            #   VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=<int>         -> max(1, int)
            if "num_gpu_blocks_override" not in kwargs:
                _v1_kv = os.environ.get("VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS", "").strip()
                if not _v1_kv:
                    kwargs["num_gpu_blocks_override"] = 1
                elif _v1_kv.lower() not in ("auto", "profile"):
                    try:
                        _v1_blocks = int(_v1_kv)
                    except ValueError as e:
                        raise ValueError(
                            "VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS must be 'auto', "
                            f"'profile' or an integer; got {_v1_kv!r}"
                        ) from e
                    kwargs["num_gpu_blocks_override"] = max(1, _v1_blocks)

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            cpu_offload_gb=cpu_offload_gb,
            offload_group_size=offload_group_size,
            offload_num_in_group=offload_num_in_group,
            offload_prefetch_step=offload_prefetch_step,
            offload_params=offload_params or set(),
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by ``get_default_sampling_params``.
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.renderer = self.llm_engine.renderer
        self.chat_template = load_chat_template(chat_template)
        self.io_processor = self.llm_engine.io_processor
        self.input_processor = self.llm_engine.input_processor
        self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
        self.pooling_io_processors = init_pooling_io_processors(
            supported_tasks=supported_tasks,
            model_config=self.model_config,
            renderer=self.renderer,
            chat_template_config=self.chat_template_config,
        )

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

        # Lazy compactor engine (``vllm.kvprune``) when :meth:`generate` uses compression.
        self._kvprune_compactor_engine: Any = None
        self._kvprune_compression_enabled = bool(kvprune_compression)
443

444
    def get_tokenizer(self) -> TokenizerLike:
445
        return self.llm_engine.get_tokenizer()
446

447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
    def get_world_size(self, include_dp: bool = True) -> int:
        """Get the world size from the parallel config.

        Args:
            include_dp: If True (default), returns the world size including
                data parallelism (TP * PP * DP). If False, returns the world
                size without data parallelism (TP * PP).

        Returns:
            The world size (tensor_parallel_size * pipeline_parallel_size),
            optionally multiplied by data_parallel_size if include_dp is True.
        """
        parallel_config = self.llm_engine.vllm_config.parallel_config
        if include_dp:
            return parallel_config.world_size_across_dp
        return parallel_config.world_size

464
    def reset_mm_cache(self) -> None:
465
        self.renderer.clear_mm_cache()
466
467
        self.llm_engine.reset_mm_cache()

468
    def get_default_sampling_params(self) -> SamplingParams:
        """Return a fresh ``SamplingParams`` built from the model's defaults.

        The diff reported by ``model_config.get_diff_sampling_param()`` is
        fetched on first use and cached in ``self.default_sampling_params``.
        If that diff is empty (falsy), a default ``SamplingParams()`` is
        returned instead.
        """
        cached = self.default_sampling_params
        if cached is None:
            cached = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = cached

        if not cached:
            return SamplingParams()
        return SamplingParams.from_optional(**cached)

475
476
    def generate(
        self,
477
478
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
479
        *,
480
        use_tqdm: bool | Callable[..., tqdm] = True,
481
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
482
        priority: list[int] | None = None,
483
        tokenization_kwargs: dict[str, Any] | None = None,
chenzk's avatar
chenzk committed
484
        compression: "CompressionParams | Sequence[CompressionParams] | None" = None,
485
    ) -> list[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
486
487
        """Generates the completions for the input prompts.

488
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
489
490
491
492
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
493
            prompts: The prompts to the LLM. You may pass a sequence of prompts
494
                for batch inference. See [PromptType][vllm.inputs.PromptType]
495
                for more details about the format of each prompt.
Woosuk Kwon's avatar
Woosuk Kwon committed
496
            sampling_params: The sampling parameters for text generation. If
nunjunj's avatar
nunjunj committed
497
498
499
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
500
                prompts and it is paired one by one with the prompt.
501
502
503
504
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
505
            lora_request: LoRA request to use for generation, if any.
506
507
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
508
509
510
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.
511
            tokenization_kwargs: Overrides for `tokenizer.encode`.
chenzk's avatar
chenzk committed
512
513
514
515
516
517
518
519
520
            compression: Optional per-prompt KV compression (``vllm.kvprune``). If any
                prompt has ``compression_ratio < 1.0``, the batch is run on the integrated
                compactor engine with weights shared from this ``LLM``. Omit or use all
                ``compression_ratio >= 1`` to use the standard v1 engine only.
                Use ``kvprune_compression=True`` or ``VLLM_KVPRUNE_COMPRESSION_DEFAULT=1``
                so the v1 engine skips CUDA graph capture. Compactor decode graphs
                default on (``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` default ``1``) with
                eager fallback if capture fails; set ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``
                to skip compactor graph capture entirely.
Woosuk Kwon's avatar
Woosuk Kwon committed
521
522

        Returns:
523
            A list of `RequestOutput` objects containing the
524
525
            generated completions in the same order as the input prompts.
        """
526
        runner_type = self.model_config.runner_type
527
        if runner_type != "generate":
528
529
530
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
531
532
                "generative model."
            )
chenzk's avatar
chenzk committed
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
        compression_eff = compression
        if compression is None and getattr(self, "_kvprune_compression_enabled", False):
            pc = self.llm_engine.vllm_config.parallel_config
            if (
                pc.tensor_parallel_size > 1
                and pc.pipeline_parallel_size == 1
                and pc.data_parallel_size == 1
            ):
                from vllm.kvprune.integration.compression_params import CompressionParams
                from vllm.kvprune.integration.compressed_generate import (
                    _normalize_prompt_list,
                )

                _plist = _normalize_prompt_list(prompts)
                compression_eff = [
                    CompressionParams(compression_ratio=1.0) for _ in _plist
                ]

        if compression_eff is not None:
            from vllm.kvprune.integration.compressed_generate import (
                try_compressed_generate,
            )

            compressed_out = try_compressed_generate(
                self,
                prompts,
                sampling_params,
                compression=compression_eff,
                use_tqdm=use_tqdm,
                lora_request=lora_request,
                priority=priority,
                tokenization_kwargs=tokenization_kwargs,
            )
            if compressed_out is not None:
                return compressed_out
568

569
        if sampling_params is None:
570
            sampling_params = self.get_default_sampling_params()
571

572
        return self._run_completion(
573
            prompts=prompts,
574
            params=sampling_params,
575
            output_type=RequestOutput,
576
            use_tqdm=use_tqdm,
577
            lora_request=lora_request,
578
            tokenization_kwargs=tokenization_kwargs,
579
580
            priority=priority,
        )
581

582
583
584
585
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
586
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Enqueue prompts for generation without waiting for completion.

        This method adds requests to the engine queue but does not start
        processing them. Use wait_for_completion() to process the queued
        requests and get results.

        Args:
            prompts: The prompts to the LLM. See generate() for details.
            sampling_params: The sampling parameters for text generation.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
            use_tqdm: If True, shows a tqdm progress bar while adding requests.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of request IDs for the enqueued requests.
        """
608
        runner_type = self.model_config.runner_type
609
610
611
612
613
614
        if runner_type != "generate":
            raise ValueError("LLM.enqueue() is only supported for generative models.")

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

615
616
617
618
619
620
621
        return self._add_completion_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
622
623
        )

624
    @overload
    def wait_for_completion(
        self,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput | PoolingRequestOutput]: ...

    @overload
    def wait_for_completion(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]: ...

    def wait_for_completion(
        self,
        output_type: type[Any] | tuple[type[Any], ...] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[Any]:
        """Wait for all enqueued requests to complete and return results.

        This method runs the engine over every request currently in its
        queue and returns their outputs. Use it after `enqueue()` to collect
        the results.

        Args:
            output_type: The output type (or tuple of types) to collect.
                Defaults to `(RequestOutput, PoolingRequestOutput)`, so both
                generative and pooling outputs are accepted.
            use_tqdm: If True, shows a tqdm progress bar. Per the annotation,
                a callable returning a tqdm instance may also be passed.

        Returns:
            A list of output objects for all completed requests.
        """
        if output_type is None:
            # Default: accept either kind of output.
            output_type = (RequestOutput, PoolingRequestOutput)

        return self._run_engine(output_type, use_tqdm=use_tqdm)
661

Cyrus Leung's avatar
Cyrus Leung committed
662
    def _resolve_mm_lora(
        self,
        prompt: ProcessorInputs,
        lora_request: LoRARequest | None,
    ) -> LoRARequest | None:
        """Choose the LoRA adapter to attach to one processed prompt.

        Non-multimodal prompts, and setups without any
        `lora_config.default_mm_loras`, simply keep the caller's request.
        For a multimodal prompt whose modalities match exactly one registered
        default modality LoRA, that default is used, unless the caller passed
        an explicit `lora_request`, which always takes precedence.

        Args:
            prompt: The processed prompt.
            lora_request: The LoRA request supplied by the caller, if any.

        Returns:
            The LoRA request to use for this prompt, or None.
        """
        if prompt["type"] != "multimodal":
            return lora_request

        lora_config = self.llm_engine.vllm_config.lora_config
        mm_lora_paths = (
            lora_config.default_mm_loras if lora_config is not None else None
        )
        if not mm_lora_paths:
            return lora_request

        matched = set(prompt["mm_placeholders"].keys()).intersection(
            mm_lora_paths.keys()
        )
        if not matched:
            return lora_request

        if len(matched) > 1:
            # TODO: support attaching several LoRAs to a single prompt.
            logger.warning(
                "Multiple modality specific loras were registered and would be "
                "used by a single prompt consuming several modalities; "
                "currently we only support one lora per request; as such, "
                "lora(s) registered with modalities: %s will be skipped",
                matched,
            )
            return lora_request

        (modality,) = matched

        # A default modality LoRA's ID is its 1-based position among the
        # registered modality names in alphabetical order.
        lora_id = sorted(mm_lora_paths).index(modality) + 1

        if lora_request:
            # The explicit request always wins; only warn when its ID differs
            # from the modality default.
            if lora_request.lora_int_id != lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(modality, lora_id, mm_lora_paths[modality])

714
715
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Run an RPC on every worker and collect the per-worker results.

        Args:
            method: Either the name of a worker method to invoke, or a
                callable that is serialized and shipped to each worker.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Upper bound in seconds on how long to wait. A
                [`TimeoutError`][] is raised when it elapses; `None` waits
                forever.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with one result per worker.

        Note:
            Prefer this API for control messages only; set up a separate
            data-plane channel for moving data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
745
746

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Invoke `func` on the model held by each worker and collect the
        return value from every worker.

        !!! warning
            Results are transferred back from the workers, so avoid returning
            large arrays or tensors. If you must, move them to CPU first so
            they do not occupy additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
758

759
760
    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        At every step each live beam is extended by one token. The engine is
        asked for the top `2 * beam_width` logprobs of the next token, each
        candidate token forms a new beam, and only the best `beam_width` beams
        per prompt are kept. Beams that emit EOS (unless `ignore_eos` is set)
        are moved to the completed set instead.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar. It is
                disabled (with a warning) when `concurrency_limit` is set.
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
                When given, it must be a positive integer.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, holding up to
            `beam_width` sequences sorted by the beam ranking key (descending)
            with their decoded `text` filled in.

        Raises:
            ValueError: If `concurrency_limit` is not a positive integer.
            NotImplementedError: If any prompt is an embeddings prompt.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency penalty, and stopping criteria, etc.?
        if concurrency_limit is not None and concurrency_limit <= 0:
            raise ValueError(
                "concurrency_limit must be a positive integer, "
                f"got {concurrency_limit}"
            )

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_prompts = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # Process all prompts in one batch. Clamp to 1 so an empty prompt
            # list does not produce a zero `range` step (which would raise).
            concurrency_limit = max(len(engine_prompts), 1)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )

        instances: list[BeamSearchInstance] = []
        for lora_req, prompt in zip(lora_requests, engine_prompts):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )

            for _ in token_iter:
                # Flatten the live beams of every instance in the batch in
                # linear time, and record each instance's [start, end) slice
                # into the flattened list.
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                if not all_beams:
                    break

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)

                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after the last step also count as candidates.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

916
    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Turn completion-style prompts (from every LLM API except
        [LLM.chat][]) into inputs suitable for `_add_request`.

        See [LLM.generate][] for the full meaning of the arguments.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = [parse_model_prompt(model_config, p) for p in prompts]

        # Caller-supplied tokenizer overrides, layered on the completion defaults.
        overrides = tokenization_kwargs if tokenization_kwargs else {}
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(**overrides)

        return renderer.render_cmpl(parsed_prompts, tok_params)
941

942
943
944
945
946
947
948
949
    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """Preprocess exactly one completion prompt via `_preprocess_cmpl`."""
        rendered = self._preprocess_cmpl([prompt], tokenization_kwargs)
        [engine_prompt] = rendered
        return engine_prompt

950
951
    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Render chat conversations into prompts that the other LLM APIs can
        consume.

        See [LLM.chat][] for the full meaning of the arguments.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer

        # Explicit keyword arguments are merged with the user-provided
        # template kwargs via `merge_kwargs`.
        template_kwargs = merge_kwargs(
            chat_template_kwargs,
            {
                "add_generation_prompt": add_generation_prompt,
                "continue_final_message": continue_final_message,
                "tools": tools,
                "tokenize": is_mistral_tokenizer(renderer.tokenizer),
            },
        )
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=template_kwargs,
        )

        overrides = tokenization_kwargs or {}
        tok_params = renderer.default_chat_tok_params.with_kwargs(**overrides)

        rendered = renderer.render_chat(
            conversations,
            chat_params,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )
        # render_chat returns a pair; only the engine prompts are needed here.
        engine_prompts = rendered[1]
        return engine_prompts
998

999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """Preprocess a single conversation via `_preprocess_chat`."""
        rendered = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        [engine_prompt] = rendered
        return engine_prompt

1025
1026
    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A sequence of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "auto" (the default) lets the format be chosen
                  automatically (presumably from the chat template).
                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, appends the generation prompt
                (the cue that starts the assistant's reply) to each
                conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions to expose to the chat template, if any.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the loaded model is not a generative model.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        return self._run_chat(
            messages=messages,
            params=sampling_params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

1119
1120
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict with a
                `"data"` key is routed through the installed IOProcessor plugin.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If `pooling_task` is missing, the model is not a
                pooling model, a `"data"` prompt is given without an
                IOProcessor plugin or with `None` data, or a pooling param's
                task conflicts with `pooling_task`.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if isinstance(prompts, dict) and "data" in prompts:
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            prompt_data = prompts.get("data")
            if prompt_data is None:
                raise ValueError(
                    "The 'data' field of the prompt is expected to contain "
                    "the prompt data and it cannot be None. "
                    "Refer to the documentation of the IOProcessor "
                    "in use for more details."
                )
            validated_prompt = self.io_processor.parse_data(prompt_data)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
            prompts_seq = prompt_to_seq(prompts)

            params_seq: Sequence[PoolingParams] = [
                self.io_processor.merge_pooling_params(param)
                for param in self._params_to_seq(
                    pooling_params,
                    len(prompts_seq),
                )
            ]
            for p in params_seq:
                if p.task is None:
                    p.task = "plugin"

            outputs = self._run_completion(
                prompts=prompts_seq,
                params=params_seq,
                output_type=PoolingRequestOutput,
                use_tqdm=use_tqdm,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )

            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(outputs)

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

            prompts_seq = prompt_to_seq(prompts)
            params_seq = self._params_to_seq(pooling_params, len(prompts_seq))

            for param in params_seq:
                if param.task is None:
                    param.task = pooling_task
                elif param.task != pooling_task:
                    msg = (
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )
                    raise ValueError(msg)

            if pooling_task in self.pooling_io_processors:
                io_processor = self.pooling_io_processors[pooling_task]
                processor_inputs = io_processor.pre_process_offline(
                    prompts_seq, tokenization_kwargs
                )
                seq_lora_requests = self._lora_request_to_seq(
                    lora_request, len(prompts_seq)
                )
                # Size priorities by the normalized prompt sequence, not the raw
                # `prompts` argument: for a single string or dict prompt,
                # `len(prompts)` is the character/key count, not 1.
                seq_priority = self._priority_to_seq(None, len(prompts_seq))

                self._render_and_add_requests(
                    prompts=processor_inputs,
                    params=params_seq,
                    lora_requests=seq_lora_requests,
                    priorities=seq_priority,
                )

                outputs = self._run_engine(
                    use_tqdm=use_tqdm, output_type=PoolingRequestOutput
                )
                outputs = io_processor.post_process_offline(outputs)
            else:
                outputs = self._run_completion(
                    prompts=prompts_seq,
                    params=params_seq,
                    output_type=PoolingRequestOutput,
                    use_tqdm=use_tqdm,
                    lora_request=lora_request,
                    tokenization_kwargs=tokenization_kwargs,
                )
        return outputs
1288

1289
1290
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        Prompts are batched automatically with the memory budget in mind;
        for best throughput, pass all prompts in a single call.

        Args:
            prompts: The prompts to the LLM. A sequence may be passed for
                batch inference. See [PromptType][vllm.inputs.PromptType]
                for the accepted formats of each prompt.
            pooling_params: The pooling parameters. If None, the default
                pooling parameters are used.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            `EmbeddingRequestOutput` objects holding the embedding vectors,
            in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" in self.supported_tasks:
            pooled = self.encode(
                prompts,
                pooling_task="embed",
                pooling_params=pooling_params,
                use_tqdm=use_tqdm,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            return list(map(EmbeddingRequestOutput.from_base, pooled))

        raise ValueError(
            "Embedding API is not supported by this model. "
            "Try converting the model using `--convert embed`."
        )

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling (one shared
                object or one per prompt). If None, we use the default pooling
                parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            per-prompt classification outputs in the same order as the input
            prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1389
1390
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute reward outputs for each prompt.

        This is a thin wrapper over `encode` using the "token_classify"
        pooling task.

        Args:
            prompts: A single prompt or a sequence of prompts for batch
                inference. See [PromptType][vllm.inputs.PromptType] for the
                accepted formats.
            pooling_params: One pooling-parameter object shared by all
                prompts, or one per prompt. `None` selects the defaults.
            use_tqdm: `True` shows a tqdm progress bar; a callable (e.g.
                `functools.partial(tqdm, leave=False)`) is used to build the
                bar instead; `False` disables the bar.
            lora_request: LoRA request(s) to apply, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            `PoolingRequestOutput` objects holding the pooled hidden states,
            ordered like the input prompts.
        """
        pooling_task = "token_classify"
        return self.encode(
            prompts,
            pooling_task=pooling_task,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            use_tqdm=use_tqdm,
        )

1428
1429
    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their pooled embeddings.

        Every item of `data_1` followed by every item of `data_2` is embedded
        in a single `encode` call (pooling task "embed"). When `data_1` holds
        exactly one embedding, it is repeated to pair with every `data_2` item.

        Raises:
            NotImplementedError: If any item is not a plain string.
        """
        tokenizer = self.get_tokenizer()

        texts: list[str] = []
        for item in data_1 + data_2:
            # Only text inputs can be embedded-and-compared here.
            if isinstance(item, str):
                texts.append(item)
                continue
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )

        embeddings = self.encode(
            texts,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            use_tqdm=use_tqdm,
        )

        split_at = len(data_1)
        query_embs = embeddings[0:split_at]
        doc_embs = embeddings[split_at:]

        # 1 -> N: broadcast the single query embedding across all documents.
        if len(query_embs) == 1:
            query_embs = query_embs * len(doc_embs)

        similarities = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=query_embs,
            embed_2=doc_embs,
        )
        return [ScoringRequestOutput.from_base(sim) for sim in similarities]
1470

1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Late interaction scoring (ColBERT MaxSim).

        Queries and documents are encoded into per-token embeddings (pooling
        task "token_embed"); each pair is then scored with
        `compute_maxsim_score`. A single query is paired with every document.
        The token ids of each result are the query ids, then the tokenizer's
        pad token (if it has one), then the document ids.
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()
        model_config = self.model_config

        # ScoreData may be text or multimodal; convert each side to prompts.
        query_prompts = score_data_to_prompts(data_1, "query", model_config)
        doc_prompts = score_data_to_prompts(data_2, "document", model_config)

        token_embs: list[PoolingRequestOutput] = self.encode(
            query_prompts + doc_prompts,
            pooling_task="token_embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            use_tqdm=use_tqdm,
        )

        n_queries = len(query_prompts)
        query_outs: list[PoolingRequestOutput] = token_embs[:n_queries]
        doc_outs: list[PoolingRequestOutput] = token_embs[n_queries:]

        if len(query_outs) == 1:
            query_outs = query_outs * len(doc_outs)

        pad_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_id is None else [pad_id]

        results: list[PoolingRequestOutput] = []
        for query_out, doc_out in zip(query_outs, doc_outs):
            score_value = compute_maxsim_score(
                query_out.outputs.data, doc_out.outputs.data
            )
            joined_ids = query_out.prompt_token_ids + separator + doc_out.prompt_token_ids
            results.append(
                PoolingRequestOutput(
                    request_id=f"{query_out.request_id}_{doc_out.request_id}",
                    outputs=PoolingOutput(data=score_value),
                    prompt_token_ids=joined_ids,
                    num_cached_tokens=(
                        query_out.num_cached_tokens + doc_out.num_cached_tokens
                    ),
                    finished=True,
                )
            )

        return [ScoringRequestOutput.from_base(result) for result in results]
1538

1539
1540
    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder.

        Each pair is rendered into a single prompt by `get_score_prompt`
        and run as one pooling request. A single-item `data_1` is paired with
        every item of `data_2`.

        The caller's `pooling_params` is never mutated: when its task is
        unset, a clone with task "score" is used instead.

        Raises:
            ValueError: If the tokenizer is a Mistral tokenizer.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if is_mistral_tokenizer(tokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        elif pooling_params.task is None:
            # Work on a copy: assigning the task in place would leak "score"
            # into any later use of the caller's object.
            pooling_params = pooling_params.clone()
            pooling_params.task = "score"

        pooling_params_list = list[PoolingParams]()
        prompts = list[PromptType]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                # token_type_ids are forwarded (compressed) through a
                # per-request clone's extra_kwargs rather than in the prompt.
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        outputs = self._run_completion(
            prompts=prompts,
            params=pooling_params_list,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        return [ScoringRequestOutput.from_base(item) for item in outputs]
1599

1600
1601
    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for pairs `<text, text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        Accepted shapes are `1 -> 1`, `1 -> N` and `N -> N`. In the `1 -> N`
        case, `data_1` is replicated `N` times so that it pairs with each
        element of `data_2`. Prompts are batched automatically subject to
        memory limits; for the best throughput pass all inputs in one call.

        Text and multi-modal data (images, etc.) are both accepted when the
        model is multi-modal; the prompt structure must match the model's
        expected input format.

        The scoring strategy follows `model_config.score_type`:
        "cross-encoder" scores each rendered pair directly,
        "late-interaction" uses MaxSim over per-token embeddings, and any
        other type uses cosine similarity of pooled embeddings.

        Args:
            data_1: A single prompt, a list of prompts, or
                `ScoreMultiModalParam` (text or multi-modal data). When a
                list, it must have the same length as `data_2`.
            data_2: The data paired with `data_1`; text or multi-modal. See
                [PromptType][vllm.inputs.PromptType] for the accepted formats.
            use_tqdm: `True` shows a tqdm progress bar; a callable (e.g.
                `functools.partial(tqdm, leave=False)`) is used to build the
                bar instead; `False` disables the bar.
            pooling_params: Pooling parameters; `None` selects the defaults.
            lora_request: LoRA request(s) to apply, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: Template used to build cross-encoder prompts. If
                None, the model's default chat template is used. Only valid
                for cross-encoder models.

        Returns:
            `ScoringRequestOutput` objects, ordered like the input pairs.

        Raises:
            ValueError: If the model is not a pooling model, cannot score,
                is a cross-encoder without `num_labels == 1`, or if
                `chat_template` is given for a non-cross-encoder model.
        """
        model_config = self.model_config

        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        score_type = self.model_config.score_type
        is_late_interaction = score_type == "late-interaction"
        is_cross_encoder = score_type == "cross-encoder"

        # Late-interaction models (e.g. ColBERT) score via token_embed and so
        # need neither "embed" nor "classify".
        if not is_late_interaction and not any(
            task in supported_tasks for task in ("embed", "classify")
        ):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if is_cross_encoder and getattr(model_config.hf_config, "num_labels", 0) != 1:
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if chat_template is not None and not is_cross_encoder:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=model_config.is_multimodal_model,
            architecture=model_config.architecture,
        )

        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        # Keyword arguments shared by all three scoring back-ends.
        common: dict[str, Any] = dict(
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tok_params.get_encode_kwargs(),
        )

        if is_cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                score_template=chat_template,
                **common,
            )
        if is_late_interaction:
            return self._late_interaction_score(score_data_1, score_data_2, **common)
        return self._embedding_score(score_data_1, score_data_2, **common)
1733

1734
1735
1736
1737
1738
1739
1740
1741
1742
    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Begin profiling, forwarding an optional trace-name prefix.

        Args:
            profile_prefix: Optional prefix for trace file names. When given,
                trace files are named "<prefix>_dp<X>_pp<Y>_tp<Z>"; otherwise
                the default naming is used.
        """
        engine = self.llm_engine
        engine.start_profile(profile_prefix)
1743
1744
1745
1746

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1747
1748
1749
1750
1751
1752
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Both flags are forwarded positionally to
        `llm_engine.reset_prefix_cache`, whose boolean result is returned.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(reset_running_requests, reset_connector)
1753

1754
    def sleep(self, level: int = 1, mode: PauseMode = "abort"):
        """
        Put the engine to sleep so that it processes no requests.

        The caller must ensure that no requests are being processed between
        this call and the matching `wake_up`.

        Args:
            level: How much state to release.
                - Level 0: Pause scheduling while still accepting requests;
                           they are queued but not processed.
                - Level 1: Offload model weights to CPU and discard the KV
                           cache (its contents are forgotten). Suited to
                           sleeping and then running the same model again;
                           requires enough CPU memory to hold the weights.
                - Level 2: Discard all GPU memory (weights and KV cache).
                           Suited to switching to a different model or
                           updating the current one, where the previous
                           weights are not needed; reduces CPU memory
                           pressure.
            mode: How to handle existing requests: "abort", "wait", or
                "keep".
        """
        engine = self.llm_engine
        engine.sleep(mode=mode, level=level)
1778

1779
    def wake_up(self, tags: list[str] | None = None):
        """
        Wake the engine from sleep mode; see [sleep][vllm.LLM.sleep].

        Args:
            tags: Optional subset of memory allocations to restore, drawn
                from `("weights", "kv_cache", "scheduling")`. `None`
                restores everything. Before the engine is used again,
                wake_up must have been called with all tags (or None).
                Pass tags=["scheduling"] to resume from level 0 sleep.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1793

1794
1795
1796
1797
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as produced by
            `llm_engine.get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1806
    def _params_to_seq(
        self,
        params: _P | Sequence[_P],
        num_requests: int,
    ) -> Sequence[_P]:
        """Broadcast `params` to one entry per request.

        Args:
            params: Either a single params object shared by every request, or
                a sequence already holding one entry per request.
            num_requests: Number of requests the result must cover.

        Returns:
            `params` itself if it is a sequence, otherwise a list repeating
            the single object `num_requests` times.

        Raises:
            ValueError: If `params` is a sequence whose length differs from
                `num_requests`.
        """
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                # Report the request count (not the params object itself),
                # matching the sibling *_to_seq helpers.
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )

            return params

        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Broadcast `lora_request` to one entry per request.

        A non-sequence value (including `None`) is repeated `num_requests`
        times; a sequence is returned as-is after its length is checked.

        Raises:
            ValueError: If a sequence's length differs from `num_requests`.
        """
        if not isinstance(lora_request, Sequence):
            return [lora_request] * num_requests

        if len(lora_request) != num_requests:
            raise ValueError(
                f"The lengths of prompts ({num_requests}) "
                f"and lora_request ({len(lora_request)}) must be the same."
            )
        return lora_request
1837

1838
1839
1840
1841
1842
    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
        """Return one priority per request.

        `None` yields a default priority of 0 for every request; a list is
        returned as-is after its length is checked.

        Raises:
            ValueError: If the list length differs from `num_requests`.
        """
        if priority is None:
            return [0] * num_requests

        if len(priority) != num_requests:
            raise ValueError(
                f"The lengths of prompts ({num_requests}) "
                f"and priority ({len(priority)}) must be the same."
            )
        return priority

1854
    def _add_completion_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Render completion prompts and add them to the engine.

        The engine is not run here; see `_run_completion` for that.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: One params object shared by all prompts, or one per prompt.
            use_tqdm: Progress-bar control for the rendering loop.
            lora_request: One LoRA request shared by all prompts, or one per
                prompt.
            priority: Optional per-prompt priorities (defaults to 0).
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            The ids of the added requests, in prompt order.
        """
        seq_prompts = prompt_to_seq(prompts)
        # Every per-request sequence is sized by the normalized prompt count.
        # `len(prompts)` is wrong when a single prompt (e.g. a plain string)
        # is passed: it measures the prompt itself, not the number of requests.
        num_requests = len(seq_prompts)
        seq_params = self._params_to_seq(params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
        seq_priority = self._priority_to_seq(priority, num_requests)

        return self._render_and_add_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tokenization_kwargs)
                for prompt in maybe_tqdm(
                    seq_prompts,
                    use_tqdm=use_tqdm,
                    desc="Rendering prompts",
                )
            ),
            params=seq_params,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Add completion requests, then run the engine until all finish.

        Returns:
            The finished outputs collected by `_run_engine`, which asserts
            each is an instance of `output_type`.
        """
        self._add_completion_requests(
            prompts,
            params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )
        return self._run_engine(output_type, use_tqdm=use_tqdm)

1908
1909
1910
1911
1912
1913
1914
    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render chat conversations lazily, add them, and run the engine.

        Each conversation is rendered by `_preprocess_chat_one` as it is
        consumed, and the resulting requests are run to completion.

        Returns:
            The finished outputs produced by `_render_and_run_requests`.
        """
        conversations = conversation_to_seq(messages)
        n_convs = len(conversations)
        per_request_params = self._params_to_seq(params, n_convs)
        per_request_loras = self._lora_request_to_seq(lora_request, n_convs)

        progress = maybe_tqdm(
            conversations,
            use_tqdm=use_tqdm,
            desc="Rendering conversations",
        )
        # A generator so rendering happens one conversation at a time.
        rendered = (
            self._preprocess_chat_one(
                conv,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=tokenization_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conv in progress
        )

        return self._render_and_run_requests(
            prompts=rendered,
            params=per_request_params,
            output_type=output_type,
            lora_requests=per_request_loras,
            use_tqdm=use_tqdm,
        )

1957
1958
1959
1960
    def _render_and_run_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add the rendered prompts to the engine and run it to completion.

        Logs a one-time warning when `prompts` is an already-materialized
        list or tuple; the message explains why a generator is preferred.

        Returns:
            The finished outputs collected by `_run_engine`.
        """
        already_rendered = isinstance(prompts, (list, tuple))
        if already_rendered:
            logger.warning_once(
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )

        self._render_and_add_requests(
            prompts,
            params,
            lora_requests=lora_requests,
            priorities=priorities,
        )
        return self._run_engine(output_type, use_tqdm=use_tqdm)
1985

1986
    def _render_and_add_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Consume `prompts` and add one engine request per prompt.

        `params`, `lora_requests` and `priorities` are indexed by the prompt's
        position. If anything fails part-way, every request added so far is
        aborted before the error propagates, so a batch is never left
        half-submitted.

        Returns:
            The ids of the added requests, in prompt order.
        """
        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(prompts):
                request_id = self._add_request(
                    prompt,
                    params[i],
                    lora_request=self._resolve_mm_lora(
                        prompt,
                        None if lora_requests is None else lora_requests[i],
                    ),
                    priority=0 if priorities is None else priorities[i],
                )
                added_request_ids.append(request_id)
        except BaseException:
            # BaseException so Ctrl-C while a lazy `prompts` generator is
            # rendering also rolls back; otherwise those requests would stay
            # unfinished in the engine and could be collected by a later
            # `_run_engine` call. Bare `raise` keeps the original traceback.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise

        return added_request_ids

2015
    def _add_request(
        self,
        prompt: ProcessorInputs,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Submit one request to the engine under a fresh sequential id.

        For `SamplingParams`, `output_kind` is forced to
        `RequestOutputKind.FINAL_ONLY` (modifying the passed object), since
        only the final output is collected.

        Returns:
            Whatever `llm_engine.add_request` returns for the new request.
        """
        if isinstance(params, SamplingParams):
            params.output_kind = RequestOutputKind.FINAL_ONLY

        new_id = str(next(self.request_counter))
        engine = self.llm_engine
        return engine.add_request(
            new_id, prompt, params, lora_request=lora_request, priority=priority
        )
2035

2036
    def _run_engine(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]:
        """Step the engine until no unfinished requests remain.

        Args:
            output_type: Expected type (or tuple of types) of every step
                output. This is checked with ``assert isinstance``.
            use_tqdm: If truthy, display a progress bar. A callable is used as
                the progress-bar factory instead of ``tqdm``.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        outputs: list[_O] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    assert isinstance(output, output_type)
                    if not output.finished:
                        continue
                    outputs.append(output)  # type: ignore[arg-type]
                    if pbar is None:
                        continue

                    if isinstance(output, RequestOutput):
                        # Token throughput is only computed for RequestOutput.
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        elapsed = pbar.format_dict["elapsed"]
                        # Guard against a zero elapsed time (coarse clocks or a
                        # very fast first completion); the speed postfix is
                        # cosmetic and must not crash the run.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s"
                            )
                        pbar.update(n)
                    else:
                        pbar.update(1)

                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Always close the bar, even if step() raises or the user
            # interrupts, so the terminal is not left with a dangling line.
            if pbar is not None:
                pbar.close()

        # Requests may finish out of submission order; request IDs come from
        # a monotonically increasing counter, so sort them numerically.
        return sorted(outputs, key=lambda x: int(x.request_id))
2090

2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Set up the weight-transfer engine used for RL training.

        Args:
            request: A ``WeightTransferInitRequest`` or an equivalent plain
                dict. In both cases its backend-specific ``init_info`` payload
                is forwarded through ``collective_rpc``.
        """
        if isinstance(request, dict):
            init_info = request["init_info"]
        else:
            init_info = request.init_info

        rpc_kwargs = dict(init_info=init_info)
        self.llm_engine.collective_rpc("init_weight_transfer_engine", kwargs=rpc_kwargs)

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Push updated model weights through the weight-transfer engine.

        Args:
            request: A ``WeightTransferUpdateRequest`` or an equivalent plain
                dict. In both cases its backend-specific ``update_info`` payload
                is forwarded through ``collective_rpc``.
        """
        if isinstance(request, dict):
            update_info = request["update_info"]
        else:
            update_info = request.update_info

        rpc_kwargs = dict(update_info=update_info)
        self.llm_engine.collective_rpc("update_weights", kwargs=rpc_kwargs)

2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model.

        The text is requested once via the ``get_model_inspection`` collective
        RPC and memoized in ``self._cached_repr``. If the RPC yields no
        results, a short ``LLM(model=...)`` summary is cached instead.
        """
        if self._cached_repr is not None:
            return self._cached_repr

        worker_views = self.llm_engine.collective_rpc("get_model_inspection")
        # With several workers each returns a view; they are expected to be
        # identical, so the first one is used.
        self._cached_repr = (
            worker_views[0]
            if worker_views
            else f"LLM(model={self.model_config.model!r})"
        )
        return self._cached_repr