"docs/vscode:/vscode.git/clone" did not exist on "42135d689830c0e764d925b6454bc68ba6c6cab4"
llm.py 74.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
15
16
17
18
19
20
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
21
    AttentionConfig,
22
    CompilationConfig,
23
    PoolerConfig,
24
    ProfilerConfig,
25
26
27
    StructuredOutputsConfig,
    is_init_field,
)
28
from vllm.config.compilation import CompilationMode
29
from vllm.config.model import (
30
31
    ConvertOption,
    HfOverrides,
32
    ModelDType,
33
    RunnerOption,
34
    TokenizerMode,
35
)
36
from vllm.engine.arg_utils import EngineArgs
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
    resolve_chat_template_content_format,
)
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
61
from vllm.inputs.parse import get_prompt_components
62
from vllm.logger import init_logger
63
from vllm.lora.request import LoRARequest
64
from vllm.model_executor.layers.quantization import QuantizationMethods
65
66
67
68
69
70
71
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
72
from vllm.platforms import current_platform
73
from vllm.pooling_params import PoolingParams
74
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
75
from vllm.tasks import PoolingTask
76
77
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
78
from vllm.usage.usage_lib import UsageContext
79
from vllm.utils.collection_utils import as_iter, is_list_of
80
from vllm.utils.counter import Counter
81
from vllm.v1.engine import EngineCoreRequest
82
from vllm.v1.engine.llm_engine import LLMEngine
83
from vllm.v1.sample.logits_processor import LogitsProcessor
84

85
86
87
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

88
89
logger = init_logger(__name__)

90
91
_R = TypeVar("_R", default=Any)

92
93

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
94
95
96
97
98
99
100
101
102
103
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
104
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
105
106
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
107
108
109
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
110
111
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
112
113
114
115
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
116
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
118
119
120
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
121
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
122
123
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
124
        quantization: The method used to quantize the model weights, e.g.
            "awq", "gptq", or "fp8"; see `QuantizationMethods` for the full
            set of accepted values.
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
130
131
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
132
133
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
134
135
136
137
138
139
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
140
141
142
143
144
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM automatically infers the KV cache
            size from `gpu_memory_utilization`. Users who want to set the KV
            cache memory size manually can use `kv_cache_memory_bytes`, which
            gives more fine-grained control over memory usage than
            `gpu_memory_utilization`. Note that when `kv_cache_memory_bytes`
            is not None, `gpu_memory_utilization` is ignored.
148
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
154
155
156
157
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
158
159
160
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
161
162
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
163
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
166
167
168
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
169
170
171
172
173
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
174
175
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
176
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the mode of compilation optimization. If it is a dictionary, it
            can specify the full compilation configuration.
179
180
181
182
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
183
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
184

185
186
    Note:
        This class is intended to be used for offline inference. For online
187
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
188
    """
189
190
191
192

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """Build the underlying `LLMEngine` from the constructor arguments.

        Dict / None config arguments are normalized into their config classes,
        single-process data parallelism is rejected, and the assembled
        `EngineArgs` are used to create the engine. Engine-derived attributes
        (model config, processors, supported tasks) are cached on `self`.

        Raises:
            ValueError: If `kv_transfer_config` is a dict that fails validation,
                or if `data_parallel_size > 1` is requested without the
                external launcher backend on a non-TPU platform.
        """
        # Stats logging defaults to off unless the caller set it explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type (rather than a qualified name string)
        # is serialized with cloudpickle to avoid pickling issues.
        if isinstance(kwargs.get("worker_cls"), type):
            kwargs["worker_cls"] = cloudpickle.dumps(kwargs["worker_cls"])

        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_kv_transfer)
            except ValidationError as err:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer,
                    err,
                )
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {err}"
                ) from err

        hf_overrides = {} if hf_overrides is None else hf_overrides

        def _as_config(raw: Any, config_cls: type[_R]) -> _R:
            """Turn None, a dict, or an existing instance into a config object."""
            if raw is None:
                return config_cls()
            if not isinstance(raw, dict):
                # Already a config instance: pass it through untouched.
                return raw
            # Keep only the keys that `is_init_field` accepts for config_cls.
            init_kwargs = {
                key: val for key, val in raw.items() if is_init_field(config_cls, key)
            }
            return config_cls(**init_kwargs)  # type: ignore[arg-type]

        # An integer compilation_config selects a compilation mode directly.
        compilation_config_instance = (
            CompilationConfig(mode=CompilationMode(compilation_config))
            if isinstance(compilation_config, int)
            else _as_config(compilation_config, CompilationConfig)
        )
        structured_outputs_instance = _as_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _as_config(profiler_config, ProfilerConfig)
        attention_config_instance = _as_config(attention_config, AttentionConfig)

        # Reject single-process data parallelism (it may hang) unless an
        # external launcher manages the processes or the platform is TPU.
        dp_size = int(kwargs.get("data_parallel_size", 1))
        executor_backend = kwargs.get("distributed_executor_backend")
        if (
            dp_size > 1
            and executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled in lazily by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        self.supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", self.supported_tasks)

        engine = self.llm_engine
        self.model_config = engine.model_config
        self.input_processor = engine.input_processor
        self.io_processor = engine.io_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls.
        self._cached_repr: str | None = None

354
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer owned by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
356

357
    def reset_mm_cache(self) -> None:
358
        self.input_processor.clear_mm_cache()
359
360
        self.llm_engine.reset_mm_cache()

361
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the model's default `SamplingParams`.

        The diff sampling params from the model config are fetched once and
        cached on `self.default_sampling_params`. When that mapping is
        non-empty it is applied via `SamplingParams.from_optional`; otherwise
        a plain `SamplingParams()` is returned.
        """
        cached = self.default_sampling_params
        if cached is None:
            cached = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = cached
        if not cached:
            return SamplingParams()
        return SamplingParams.from_optional(**cached)

368
369
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters
                (see `get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. For
                multi-modal models with default modality-specific LoRAs
                configured, a matching LoRA may be selected per prompt.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        runner_type = self.model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
438

439
    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ):
        """Attach default modality-specific LoRAs to multi-modal prompts.

        `lora_request` is returned unchanged in three cases: there is no LoRA
        config, the model is not multi-modal, or no `default_mm_loras` are
        configured. Otherwise, prompts are paired (via `zip`) with their
        per-prompt LoRA requests. The result is a list holding each pair's
        outcome from `_resolve_single_prompt_mm_lora`.
        """
        # The LoRA config is read from the engine's vllm config.
        lora_config = self.llm_engine.vllm_config.lora_config
        if lora_config is None or not self.model_config.is_multimodal_model:
            return lora_request
        default_mm_loras = lora_config.default_mm_loras
        if default_mm_loras is None:
            return lora_request

        # A bare string is a Sequence too, so it must be wrapped explicitly.
        is_single_prompt = isinstance(prompts, str) or not isinstance(
            prompts, Sequence
        )
        prompt_list = [prompts] if is_single_prompt else prompts

        # Broadcast a single (or absent) request across every prompt.
        if isinstance(lora_request, Sequence):
            per_prompt_loras = lora_request
        else:
            per_prompt_loras = [lora_request] * len(prompt_list)

        return [
            self._resolve_single_prompt_mm_lora(one_prompt, one_lora, default_mm_loras)
            for one_prompt, one_lora in zip(prompt_list, per_prompt_loras)
        ]

475
476
477
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ):
        """Choose the LoRA request to use for a single prompt.

        A new `LoRARequest` is built only when all of the following hold:
        the prompt is a dict with non-empty `multi_modal_data`, exactly one of
        its modalities has an entry in `default_mm_loras`, and no explicit
        `lora_request` was given. In every other case `lora_request` is
        returned unchanged; an explicit request always takes precedence.

        Args:
            prompt: The prompt to inspect.
            lora_request: The explicitly provided LoRA request, if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            The LoRA request to use for this prompt, or None.
        """
        if not default_mm_loras or not isinstance(prompt, dict):
            return lora_request
        mm_data = prompt.get("multi_modal_data") or {}
        if not mm_data:
            return lora_request

        matched_modalities = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())
        if not matched_modalities:
            return lora_request
        if len(matched_modalities) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                matched_modalities,
            )
            return lora_request

        modality_name = matched_modalities.pop()
        modality_lora_path = default_mm_loras[modality_name]
        # The ID of a default mm LoRA is the 1-based index of its modality
        # name among the alphabetically sorted keys of default_mm_loras.
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; only warn when its ID
        # differs from the modality's default LoRA ID.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

528
529
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Run an RPC on every worker and gather the per-worker results.

        Args:
            method: Either the name of a worker method to execute, or a
                callable that is serialized and sent to all workers.

                A callable must accept the worker object as an extra leading
                `self` argument, followed by the values from `args` and
                `kwargs`.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list with the result returned by each worker.

        Note:
            Prefer using this API for control messages only, and set up a
            separate data-plane communication channel to move data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
559
560

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect what it
        returns, one entry per worker.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
572

573
574
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        A single `LoRARequest` (or None) is broadcast to every prompt. A
        sequence of requests must contain exactly one entry per prompt and is
        returned as a new list.

        Raises:
            ValueError: If a sequence of requests does not have the same length
                as `prompts`.
            TypeError: If `lora_request` is neither None, a `LoRARequest`, nor
                a sequence of requests.
        """
        num_prompts = len(prompts)

        # Previously a correctly sized list passed the length check but then
        # fell through to the TypeError below; it is now returned as-is.
        if isinstance(lora_request, Sequence):
            if len(lora_request) != num_prompts:
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * num_prompts

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

589
590
    def beam_search(
        self,
591
        prompts: list[TokensPrompt | TextPrompt],
592
        params: BeamSearchParams,
593
        lora_request: list[LoRARequest] | LoRARequest | None = None,
594
        use_tqdm: bool = False,
595
        concurrency_limit: int | None = None,
596
    ) -> list[BeamSearchOutput]:
597
598
599
600
601
602
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
603
            params: The beam search parameters.
604
            lora_request: LoRA request to use for generation, if any.
605
            use_tqdm: Whether to use tqdm to display the progress bar.
606
607
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
608
        """
609
610
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
611
612
613
614
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
615
616
        length_penalty = params.length_penalty

617
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
618

619
620
621
622
623
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
624

625
626
627
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
628
629
                "Disabling progress bar."
            )
630
631
632
633
634
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

635
636
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
637
638
639
640
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
641
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
642
            return TokensPrompt(**token_prompt_kwargs)
643

644
645
646
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
647
        beam_search_params = SamplingParams(
648
649
650
651
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
652
        )
653
        instances: list[BeamSearchInstance] = []
654

655
        for lora_req, prompt in zip(lora_requests, prompts):
656
657
658
659
660
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
661
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
662

663
664
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
665
666
667
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
668

669
            instances.append(
670
671
672
673
674
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
675
676
                ),
            )
677

678
        for prompt_start in range(0, len(prompts), concurrency_limit):
679
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
680
681
682

            token_iter = range(max_tokens)
            if use_tqdm:
683
684
685
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
686
687
688
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
689
690
                    "reflect instance-level progress."
                )
691
692
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
693
694
                    sum((instance.beams for instance in instances_batch), [])
                )
695
696
                pos = [0] + list(
                    itertools.accumulate(
697
698
699
                        len(instance.beams) for instance in instances_batch
                    )
                )
700
                instance_start_and_end: list[tuple[int, int]] = list(
701
702
                    zip(pos[:-1], pos[1:])
                )
703
704
705
706
707
708

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
709
710
711
712
713
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
714
715
716

                # only runs for one step
                # we don't need to use tqdm here
717
718
719
720
721
722
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
723

724
725
726
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
727
728
729
730
731
732
733
734
735
736
737
738
739
740
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
741
                                    logprobs=current_beam.logprobs + [logprobs],
742
                                    lora_request=current_beam.lora_request,
743
744
745
746
747
748
749
750
751
752
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
753
754
755
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
756
757
758
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
759
                    instance.beams = sorted_beams[:beam_width]
760
761
762
763

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
764
765
766
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
767
768
769
770
771
772
773
774
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

775
    def preprocess_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[TokensPrompt]:
        """
        Render and tokenize one or more chat conversations.

        The returned prompts can be passed straight to the other LLM
        methods (e.g. `generate`). See `chat` for a full description of
        the arguments.

        Returns:
            One `TokensPrompt` per conversation, holding the token IDs
            produced by applying the chat template, plus any multi-modal
            data, multi-modal UUIDs and processor kwargs that apply to it.
        """
        # Accept either a single conversation or a batch of conversations.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list)
            else [cast(list[ChatCompletionMessageParam], messages)]
        )

        tokenizer = self.get_tokenizer()
        model_config = self.model_config
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Entries in the user-supplied chat_template_kwargs override the
        # explicit keyword arguments of the same name.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        token_prompts: list[TokensPrompt] = []
        for convo in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            parsed_conversation, mm_data, mm_uuids = parse_chat_messages(
                convo,
                model_config,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers render and tokenize in one step.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=convo,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=parsed_conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered, add_special_tokens=False)

            token_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                token_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                token_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                token_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            token_prompts.append(token_prompt)

        return token_prompts

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered with the chat template and tokenized
        (via `preprocess_chat`). The resulting token prompts are then handed
        to the [generate][vllm.LLM.generate] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions forwarded to the chat template; they are
                also taken into account when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. These override the explicit template arguments
                above when the keys collide.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        token_prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        # Sampling, progress reporting and LoRA selection are all
        # delegated to `generate`.
        return self.generate(
            token_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

955
956
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is treated as an IOProcessor
                plugin request.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. For IOProcessor requests
                the value (a single object or any sequence of them) is
                validated or generated by the plugin.
            truncate_prompt_tokens: Kept for backwards compatibility. If not
                None, it is written into `truncate_prompt_tokens` of every
                pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. Required; it must be one
                of the tasks supported by the model.
            tokenization_kwargs: overrides tokenization_kwargs set in
                pooling_params

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For IOProcessor requests, a single-element list wrapping the
            plugin's post-processed output.

        Raises:
            ValueError: If `pooling_task` is missing or not supported by the
                model, if the model is not a pooling model, or if an
                IOProcessor request is given but no plugin is installed.
        """
        if pooling_task is None:
            # The help text is only assembled when it is actually raised.
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        if io_processor_prompt:
            io_processor = self.io_processor
            assert io_processor is not None
            if isinstance(pooling_params, Sequence):
                # Accept any sequence (list, tuple, ...) of params, as the
                # `Sequence[PoolingParams]` annotation promises; checking only
                # for `list` made tuples fall through to a failing assert.
                pooling_params = [
                    io_processor.validate_or_generate_params(param)
                    for param in pooling_params
                ]
            else:
                pooling_params = io_processor.validate_or_generate_params(
                    pooling_params
                )
        elif pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs
            )

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            return model_outputs
1105

1106
1107
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute an embedding vector for every prompt.

        Prompts are batched automatically within the memory budget; for the
        best throughput, pass all of your prompts in a single call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`; if not None it is
                applied to every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Extra tokenization kwargs forwarded to
                `encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        supports_embedding = "embed" in self.supported_tasks
        if not supports_embedding:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        # Re-wrap the generic pooling outputs as embedding-specific outputs.
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            tokenization_kwargs: Extra tokenization kwargs forwarded to
                `encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        # Re-wrap the generic pooling outputs as classification outputs.
        return [ClassificationRequestOutput.from_base(item) for item in items]

1205
1206
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute rewards for each prompt.

        This is `encode` run with the "token_classify" pooling task.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`; if not None it is
                applied to every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Extra tokenization kwargs forwarded to
                `encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        rewards = self.encode(
            prompts,
            pooling_task="token_classify",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return rewards

1245
1246
    def _embedding_score(
        self,
        tokenizer: TokenizerLike,
        text_1: list[str | TextPrompt | TokensPrompt],
        text_2: list[str | TextPrompt | TokensPrompt],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        Both sides are embedded together in a single `encode` call with the
        "embed" task. If `text_1` yields exactly one embedding, it is paired
        with every embedding of `text_2` (1 -> N scoring).
        """
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

        # The first len(text_1) outputs belong to text_1, the rest to text_2.
        split_at = len(text_1)
        query_embeds: list[PoolingRequestOutput] = embeddings[:split_at]
        doc_embeds: list[PoolingRequestOutput] = embeddings[split_at:]

        # Broadcast a single query against every document.
        if len(query_embeds) == 1:
            query_embeds = query_embeds * len(doc_embeds)

        similarity = _cosine_similarity(
            tokenizer=tokenizer, embed_1=query_embeds, embed_2=doc_embeds
        )
        validated = self.engine_class.validate_outputs(
            similarity, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: TokenizerLike,
        data_1: list[str] | list[ScoreContentPartParam],
        data_2: list[str] | list[ScoreContentPartParam],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        score_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair with a cross-encoder.

        Every pair is turned into one engine prompt by `get_score_prompt`
        and run with the "score" pooling task. A single `data_1` entry is
        replicated to pair with every `data_2` entry.

        Raises:
            ValueError: If `tokenizer` is a Mistral tokenizer.
        """
        model_config = self.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N: replicate the single left-hand input.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        base_params = (
            pooling_params
            if pooling_params is not None
            else PoolingParams(task="score")
        )
        base_params.verify("score", model_config)

        # Work on a private copy so the caller's dict is never modified.
        tokenization_kwargs = (tokenization_kwargs or {}).copy()
        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
        )

        engine_prompts = list[PromptType]()
        request_params = list[PoolingParams]()
        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # token_type_ids are removed from the prompt and, when present,
            # shipped in compressed form through a per-request params clone.
            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if not token_type_ids:
                request_params.append(base_params)
            else:
                cloned = base_params.clone()
                cloned.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
                request_params.append(cloned)

            engine_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=request_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(
            raw_outputs, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(out) for out in validated]

1348
1349
    def score(
        self,
        data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional prompt-truncation length,
                forwarded unchanged to the cross-encoder / embedding scoring
                helper.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The chat template to use for the scoring. If None, we
                use the model's default chat template. Only allowed for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, supports
                neither `embed` nor `classify`, is a cross-encoder whose
                `num_labels != 1`, if `chat_template` is given for a
                non-cross-encoder model, or if the inputs cannot be scored by
                a text-only model (multi-modal data or non-text prompts).
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if not model_config.is_cross_encoder and chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(
                data: SingletonPrompt
                | Sequence[SingletonPrompt]
                | ScoreMultiModalParam,
            ):
                # A dict carrying "content" is a ScoreMultiModalParam, which a
                # text-only model cannot consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt) -> str:
                # Normalize a text or token prompt into plain text, because
                # the tokenizer may not accept token lists (see note above).
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check instead of `assert` so unsupported inputs are
                # still rejected when Python runs with `-O`.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Scoring prompts must be strings, TextPrompts or "
                        f"TokensPrompts, got {type(prompt).__name__}"
                    )
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap ScoreMultiModalParam into its list of content parts; promote
        # a lone string to a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
                score_template=chat_template,
            )
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
1506

1507
1508
1509
1510
1511
1512
    def start_profile(self) -> None:
        """Start profiling by delegating to `self.llm_engine.start_profile()`."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to `self.llm_engine.stop_profile()`."""
        engine = self.llm_engine
        engine.stop_profile()

1513
1514
1515
1516
1517
1518
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Args:
            reset_running_requests: Forwarded positionally to the engine.
            reset_connector: Forwarded positionally to the engine.

        Returns:
            Whatever `self.llm_engine.reset_prefix_cache` returns.
        """
        engine = self.llm_engine
        succeeded = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return succeeded
1519

1520
1521
1522
1523
1524
1525
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it must not process requests while asleep.
        The caller is responsible for ensuring that no requests are in flight
        from the moment this is called until `wake_up` is invoked.

        Args:
            level: How deeply to sleep.
                - Level 1 offloads the model weights, backing them up in CPU
                  memory, and discards the kv cache (its content is lost).
                  Use it to sleep and later wake up the same model again. Make
                  sure there is enough CPU memory to hold the model weights.
                - Level 2 discards both the model weights and the kv cache;
                  the content of both is lost. Use it to sleep and later wake
                  up with a different or updated model, when the previous
                  weights are no longer needed. It reduces CPU memory pressure.
        """
        # Drop cached prefixes first, since the kv cache is discarded anyway.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1542
    def wake_up(self, tags: list[str] | None = None):
1543
        """
1544
1545
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1546

1547
        Args:
1548
1549
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1550
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1551
                wake_up should be called with all tags (or None) before the
1552
1553
1554
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1555

1556
1557
1558
1559
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as returned by
            `self.llm_engine.get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1568
1569
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Validate the per-request arguments and add each prompt to the engine.

        Args:
            prompts: A single prompt (str or dict) or a sequence of prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt.
            use_tqdm: If truthy, wrap the prompts in a progress bar; a callable
                is used as the progress-bar factory instead of `tqdm`.
            lora_request: One LoRA request for all prompts, a per-prompt
                sequence, or None.
            priority: Optional per-prompt priorities; 0 is used when omitted.
            tokenization_kwargs: Forwarded to `_add_request` for every prompt.

        Raises:
            ValueError: If a per-prompt sequence does not match the number of
                prompts. Any exception raised while adding requests is
                re-raised after the requests already added are aborted.
        """
        if isinstance(prompts, (str, dict)):
            # A lone prompt becomes a one-element batch.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )
        if priority is not None and len(priority) != num_requests:
            raise ValueError(
                "The lengths of prompts "
                f"({num_requests}) and priority ({len(priority)}) "
                "must be the same."
            )

        # Offline generation only consumes each request's final output.
        for sp in params if params_is_seq else (params,):
            if isinstance(sp, SamplingParams):
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        if use_tqdm:
            progress_factory = use_tqdm if callable(use_tqdm) else tqdm
            prompt_iter = progress_factory(prompts, desc="Adding requests")
        else:
            prompt_iter = prompts

        added_request_ids: list[str] = []
        try:
            for idx, prompt in enumerate(prompt_iter):
                if isinstance(prompt, dict):
                    self._validate_mm_data_and_uuids(
                        prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
                    )
                request_params = params[idx] if params_is_seq else params
                request_lora = lora_request[idx] if lora_is_seq else lora_request
                request_id = self._add_request(
                    prompt,
                    request_params,
                    lora_request=request_lora,
                    priority=priority[idx] if priority else 0,
                    tokenization_kwargs=tokenization_kwargs,
                )
                added_request_ids.append(request_id)
        except Exception as e:
            # Abort whatever was already added so a failed batch is not left
            # partially submitted to the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise e
1632

1633
    def _validate_mm_data_and_uuids(
1634
        self,
1635
1636
        multi_modal_data: Any | None,  # MultiModalDataDict
        multi_modal_uuids: Any | None,  # MultiModalUUIDDict
1637
1638
1639
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1640
        then its corresponding UUID must be set.
1641
1642
1643
1644
1645
1646
1647
1648
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
1649
1650
1651
1652
1653
1654
1655
1656
                        if (
                            multi_modal_uuids is None
                            or modality not in multi_modal_uuids
                            or multi_modal_uuids[  # noqa: E501
                                modality
                            ]
                            is None
                        ):
1657
1658
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
1659
1660
                                f"but UUID is not provided"
                            )
1661
                        else:
1662
1663
1664
1665
                            if (
                                len(multi_modal_uuids[modality]) <= i
                                or multi_modal_uuids[modality][i] is None
                            ):
1666
1667
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
1668
1669
                                    f"but UUID is not provided"
                                )
1670
            else:
1671
1672
1673
1674
1675
1676
1677
1678
1679
                if data is None and (
                    multi_modal_uuids is None
                    or modality not in multi_modal_uuids
                    or multi_modal_uuids[modality] is None
                ):
                    raise ValueError(
                        f"Multi-modal data for {modality} is None"
                        f" but UUID is not provided"
                    )
1680

1681
1682
1683
1684
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        priority: int,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for LLMEngine.

        Args:
            request_id: ID assigned to the new request.
            engine_prompt: The prompt to process.
            params: Sampling or pooling parameters; its
                `truncate_prompt_tokens` is validated against the model's
                `max_model_len`.
            lora_request: Optional LoRA request for this prompt.
            priority: Scheduling priority forwarded to the input processor.
            tokenization_kwargs: Optional tokenizer kwargs; never modified.

        Returns:
            The processed engine request and the (copied) tokenization kwargs
            that were passed to the input processor.
        """
        # Copy first so any changes made below do not leak into the caller's dict.
        merged_kwargs = (tokenization_kwargs or {}).copy()

        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            merged_kwargs,
        )

        processed_request = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=merged_kwargs,
            priority=priority,
        )
        return processed_request, merged_kwargs

1711
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """Process a single prompt and submit the resulting request.

        A fresh request ID is drawn from `self.request_counter`, the prompt is
        turned into an engine request via `_process_inputs`, and the result is
        handed to `self.llm_engine.add_request`.

        Returns:
            The `request_id` of the engine request that was added.
        """
        text_of_prompt, _, _ = get_prompt_components(prompt)
        new_id = str(next(self.request_counter))

        engine_request, processed_kwargs = self._process_inputs(
            new_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )

        self.llm_engine.add_request(
            new_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=processed_kwargs,
            priority=priority,
            prompt_text=text_of_prompt,
        )
        return engine_request.request_id
1741

1742
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until it reports no unfinished requests.

        Args:
            use_tqdm: If `True`, show a tqdm progress bar. If a callable, it is
                used to create the progress bar. If `False`, no progress bar
                is created.

        Returns:
            All finished outputs, sorted by integer request ID. Request IDs
            are produced by `next(self.request_counter)` in `_add_request`,
            so this restores the order in which requests were added.
        """
        pbar = None
        num_requests = 0
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Token throughput is only tracked for RequestOutput.
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        # `elapsed` can be 0 (coarse clock, custom bar factory);
                        # report 0 toks/s instead of dividing by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        in_spd = total_in_toks / elapsed if elapsed > 0 else 0.0
                        out_spd = total_out_toks / elapsed if elapsed > 0 else 0.0
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                        pbar.update(n)
                    else:
                        pbar.update(1)
                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the progress bar even if stepping the engine raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804

    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr