llm.py 74.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
15
16
17
18
19
20
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
21
    AttentionConfig,
22
    CompilationConfig,
23
    PoolerConfig,
24
    ProfilerConfig,
25
26
27
    StructuredOutputsConfig,
    is_init_field,
)
28
from vllm.config.compilation import CompilationMode
29
from vllm.config.model import (
30
31
    ConvertOption,
    HfOverrides,
32
    ModelDType,
33
    RunnerOption,
34
    TokenizerMode,
35
)
36
from vllm.engine.arg_utils import EngineArgs
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
    resolve_chat_template_content_format,
)
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
61
from vllm.inputs.parse import get_prompt_components
62
from vllm.logger import init_logger
63
from vllm.lora.request import LoRARequest
64
from vllm.model_executor.layers.quantization import QuantizationMethods
65
66
67
68
69
70
71
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
72
from vllm.platforms import current_platform
73
from vllm.pooling_params import PoolingParams
74
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
75
from vllm.tasks import PoolingTask
76
77
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
78
from vllm.usage.usage_lib import UsageContext
79
from vllm.utils.collection_utils import as_iter, is_list_of
80
from vllm.utils.counter import Counter
81
from vllm.v1.engine import EngineCoreRequest
82
from vllm.v1.engine.llm_engine import LLMEngine
83
from vllm.v1.sample.logits_processor import LogitsProcessor
84

85

86
87
88
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

89
90
logger = init_logger(__name__)

91
92
_R = TypeVar("_R", default=Any)

93
94

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
95
96
97
98
99
100
101
102
103
104
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
105
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
106
107
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
108
109
110
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
111
112
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
113
114
115
116
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
117
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
119
120
121
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
122
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
123
124
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
125
        quantization: The method used to quantize the model weights. Currently,
126
            we support "awq", "gptq", and "fp8" (experimental).
127
128
129
130
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
131
132
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
133
134
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
135
136
137
138
139
140
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
141
142
143
144
145
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM can automatically infer the KV cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the KV cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used
            compared with using gpu_memory_utilization. Note that when
            kv_cache_memory_bytes is not None, gpu_memory_utilization is
            ignored.
149
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
155
156
157
158
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
159
160
161
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
162
        enable_return_routed_experts: Whether to return routed experts.
163
164
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
165
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
168
169
170
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
171
172
173
174
175
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
176
        pooler_config: Initialize non-default pooling config for the pooling
177
            model. e.g. `PoolerConfig(seq_pooling_type="MEAN", normalize=False)`.
178
        compilation_config: Either an integer or a dictionary. If it is an
179
            integer, it is used as the mode of compilation optimization. If it
180
            is a dictionary, it can specify the full compilation configuration.
181
182
183
184
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
185
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
186

187
188
    Note:
        This class is intended to be used for offline inference. For online
189
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
190
    """
191
192
193
194

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        The named parameters are documented on the class docstring; any
        extra keyword arguments are forwarded to `EngineArgs`.

        Raises:
            ValueError: If `kv_transfer_config` is passed as a dict that
                cannot be converted to a `KVTransferConfig`, or if
                `data_parallel_size > 1` is requested while neither the
                `external_launcher` distributed executor backend nor a TPU
                platform is in use.
        """
        # Stats logging is disabled by default; an explicit caller value wins.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # If worker_cls is a class rather than a qualified string name,
            # serialize it with cloudpickle to avoid pickling issues.
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if "kv_transfer_config" in kwargs and isinstance(
            kwargs["kv_transfer_config"], dict
        ):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            None yields a default-constructed `cls`. For a dict, only keys
            for which `is_init_field(cls, key)` is true are passed to `cls`;
            other keys are dropped. Any other value is returned unchanged.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        # An int compilation_config is interpreted as the compilation mode.
        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data-parallel usage (it may hang) unless an
        # external launcher drives the processes or the platform is TPU.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor
        self.io_processor = self.llm_engine.io_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

358
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
360

361
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches.

        The input processor's cache is cleared first, then the engine's.
        """
        processor = self.input_processor
        processor.clear_mm_cache()
        engine = self.llm_engine
        engine.reset_mm_cache()

365
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default `SamplingParams` for this model.

        The parameter diff from `model_config.get_diff_sampling_param()` is
        fetched on the first call and cached in
        `self.default_sampling_params`. A non-empty diff is applied through
        `SamplingParams.from_optional`; otherwise a plain `SamplingParams()`
        is returned.
        """
        cached_diff = self.default_sampling_params
        if cached_diff is None:
            cached_diff = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = cached_diff

        if not cached_diff:
            return SamplingParams()
        return SamplingParams.from_optional(**cached_diff)

372
373
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
                A single request is applied to every prompt; a list is paired
                one by one with the prompts.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
442

443
    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ) -> list[LoRARequest | None] | LoRARequest | None:
        """Attach default modality-specific LoRAs to multi-modal prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: The caller's LoRA request(s): None, a single
                request (applied to every prompt), or a list with one entry
                per prompt.

        Returns:
            `lora_request` unchanged when there is no LoRA config, the model
            is not multi-modal, or `default_mm_loras` is None. Otherwise, a
            list with one (possibly None) LoRA request per prompt, each
            resolved by `_resolve_single_prompt_mm_lora`.

        Raises:
            ValueError: If `lora_request` is a list whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        if not isinstance(prompts, Sequence) or isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently truncate a mismatched list and drop
            # requests, so reject it explicitly.
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        default_mm_loras = lora_config.default_mm_loras
        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

479
480
481
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ) -> LoRARequest | None:
        """Choose the LoRA request for a single prompt.

        Args:
            prompt: The prompt being submitted.
            lora_request: The LoRA request explicitly provided for this
                prompt, if any.
            default_mm_loras: Mapping from modality name to the path of the
                default LoRA registered for that modality.

        Returns:
            A new `LoRARequest` for the modality's default LoRA only when
            exactly one modality in the prompt's `multi_modal_data` has a
            default LoRA and `lora_request` is None; otherwise
            `lora_request` unchanged.
        """
        if (
            not default_mm_loras
            or not isinstance(prompt, dict)
            or not (mm_data := prompt.get("multi_modal_data") or {})
        ):
            return lora_request

        intersection = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())

        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                intersection,
            )
            return lora_request

        # Exactly one modality matches. The ID of a default mm LoRA is the
        # 1-based index of its modality name among the sorted modality names.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn when its ID
        # differs from the modality's default LoRA ID.
        if lora_request is not None:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

532
533
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Invoke a method on every worker and gather what each one returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable receives the worker object as an extra leading
                `self` argument, followed by whatever is supplied in `args`
                and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. Raises a
                [`TimeoutError`][] when exceeded; `None` waits forever.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer this API for control messages only; set up a dedicated
            data-plane channel for moving actual data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
563
564

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect every
        worker's return value.

        !!! warning
            Returning large arrays or tensors from `func` adds data-transfer
            overhead, so avoid it. If you must return them, move them to CPU
            first so they do not occupy additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
576

577
578
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None or a single request (applied to every prompt),
                or a sequence with exactly one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A new list with one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from `len(prompts)`.
            TypeError: If `lora_request` is of any other type.
        """
        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            # Copy so the returned list is independent of the caller's.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

593
594
    def beam_search(
        self,
595
        prompts: list[TokensPrompt | TextPrompt],
596
        params: BeamSearchParams,
597
        lora_request: list[LoRARequest] | LoRARequest | None = None,
598
        use_tqdm: bool = False,
599
        concurrency_limit: int | None = None,
600
    ) -> list[BeamSearchOutput]:
601
602
603
604
605
606
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
607
            params: The beam search parameters.
608
            lora_request: LoRA request to use for generation, if any.
609
            use_tqdm: Whether to use tqdm to display the progress bar.
610
611
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
612
        """
613
614
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
615
616
617
618
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
619
620
        length_penalty = params.length_penalty

621
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
622

623
624
625
626
627
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
628

629
630
631
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
632
633
                "Disabling progress bar."
            )
634
635
636
637
638
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

639
640
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
641
642
643
644
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
645
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
646
            return TokensPrompt(**token_prompt_kwargs)
647

648
649
650
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
651
        beam_search_params = SamplingParams(
652
653
654
655
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
656
        )
657
        instances: list[BeamSearchInstance] = []
658

659
        for lora_req, prompt in zip(lora_requests, prompts):
660
661
662
663
664
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
665
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
666

667
668
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
669
670
671
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
672

673
            instances.append(
674
675
676
677
678
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
679
680
                ),
            )
681

682
        for prompt_start in range(0, len(prompts), concurrency_limit):
683
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
684
685
686

            token_iter = range(max_tokens)
            if use_tqdm:
687
688
689
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
690
691
692
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
693
694
                    "reflect instance-level progress."
                )
695
696
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
697
698
                    sum((instance.beams for instance in instances_batch), [])
                )
699
700
                pos = [0] + list(
                    itertools.accumulate(
701
702
703
                        len(instance.beams) for instance in instances_batch
                    )
                )
704
                instance_start_and_end: list[tuple[int, int]] = list(
705
706
                    zip(pos[:-1], pos[1:])
                )
707
708
709
710
711
712

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
713
714
715
716
717
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
718
719
720

                # only runs for one step
                # we don't need to use tqdm here
721
722
723
724
725
726
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
727

728
729
730
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
731
732
733
734
735
736
737
738
739
740
741
742
743
744
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
745
                                    logprobs=current_beam.logprobs + [logprobs],
746
                                    lora_request=current_beam.lora_request,
747
748
749
750
751
752
753
754
755
756
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
757
758
759
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
760
761
762
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
763
                    instance.beams = sorted_beams[:beam_width]
764
765
766
767

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
768
769
770
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
771
772
773
774
775
776
777
778
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

779
    def preprocess_chat(
nunjunj's avatar
nunjunj committed
780
        self,
781
782
783
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
784
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
785
        add_generation_prompt: bool = True,
786
        continue_final_message: bool = False,
787
788
789
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
790
    ) -> list[TokensPrompt]:
nunjunj's avatar
nunjunj committed
791
        """
792
793
        Generate prompt for a chat conversation. The pre-processed
        prompt can then be used as input for the other LLM methods.
nunjunj's avatar
nunjunj committed
794

795
        Refer to `chat` for a complete description of the arguments.
nunjunj's avatar
nunjunj committed
796
        Returns:
797
798
799
            A list of `TokensPrompts` objects containing the tokenized
            prompt after chat template interpolation, and the
            pre-processed multi-modal inputs.
nunjunj's avatar
nunjunj committed
800
        """
801
        list_of_messages: list[list[ChatCompletionMessageParam]]
nunjunj's avatar
nunjunj committed
802

803
804
        # Handle multi and single conversations
        if is_list_of(messages, list):
805
            # messages is list[list[...]]
806
            list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages)
807
        else:
808
            # messages is list[...]
809
            list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
810

811
        tokenizer = self.get_tokenizer()
812
        model_config = self.model_config
813
814
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
815
            tools,
816
817
            chat_template_content_format,
            tokenizer,
818
            model_config=model_config,
819
820
        )

821
822
823
824
825
826
827
828
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

829
        prompts: list[TokensPrompt] = []
830
831

        for msgs in list_of_messages:
832
833
834
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
835
            conversation, mm_data, mm_uuids = parse_chat_messages(
836
837
838
839
                msgs,
                model_config,
                content_format=resolved_content_format,
            )
840
841

            if isinstance(tokenizer, MistralTokenizer):
842
                prompt_token_ids = apply_mistral_chat_template(
843
844
                    tokenizer,
                    messages=msgs,
845
                    **_chat_template_kwargs,
846
847
                )
            else:
848
                prompt_str = apply_hf_chat_template(
849
                    tokenizer=tokenizer,
850
                    conversation=conversation,
851
                    model_config=model_config,
852
                    **_chat_template_kwargs,
853
                )
854
855
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
856
857
858
                prompt_token_ids = tokenizer.encode(
                    prompt_str, add_special_tokens=False
                )
859

860
            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
861
862
863
864

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

865
866
867
            if mm_uuids is not None:
                prompt["multi_modal_uuids"] = mm_uuids

868
869
870
            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

871
            prompts.append(prompt)
872

873
874
875
876
        return prompts

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The conversation(s) are first rendered and tokenized with
        `preprocess_chat`. The resulting prompts are then passed to the
        [generate][vllm.LLM.generate] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. A single
                value applies to every prompt. A list must have the same
                length as the prompts and is paired with them one by one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions forwarded to the chat template. They are
                also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        tokenized_prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        responses = self.generate(
            tokenized_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
        return responses

959
960
    def encode(
        self,
961
962
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
963
        *,
964
965
966
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
967
        pooling_task: PoolingTask | None = None,
968
        tokenization_kwargs: dict[str, Any] | None = None,
969
    ) -> list[PoolingRequestOutput]:
970
971
        """Apply pooling to the hidden states corresponding to the input
        prompts.
972

973
        This class automatically batches the given prompts, considering
974
975
976
977
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
978
            prompts: The prompts to the LLM. You may pass a sequence of prompts
979
                for batch inference. See [PromptType][vllm.inputs.PromptType]
980
                for more details about the format of each prompt.
981
982
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
983
984
985
986
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
987
            lora_request: LoRA request to use for generation, if any.
988
            pooling_task: Override the pooling task to use.
989
990
            tokenization_kwargs: overrides tokenization_kwargs set in
                pooling_params
991
992

        Returns:
993
            A list of `PoolingRequestOutput` objects containing the
994
            pooled hidden states in the same order as the input prompts.
995
996

        Note:
997
            Using `prompts` and `prompt_token_ids` as keyword parameters is
998
            considered legacy and may be deprecated in the future. You should
999
            instead pass them via the `inputs` parameter.
1000
        """
1001

1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
        error_str = (
            "pooling_task required for `LLM.encode`\n"
            "Please use one of the more specific methods or set the "
            "pooling_task when using `LLM.encode`:\n"
            "  - For embeddings, use `LLM.embed(...)` "
            'or `pooling_task="embed"`.\n'
            "  - For classification logits, use `LLM.classify(...)` "
            'or `pooling_task="classify"`.\n'
            "  - For similarity scores, use `LLM.score(...)`.\n"
            "  - For rewards, use `LLM.reward(...)` "
            'or `pooling_task="token_classify"`\n'
            "  - For token classification, "
            'use `pooling_task="token_classify"`\n'
            '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
        )
1017

1018
        if pooling_task is None:
1019
            raise ValueError(error_str)
1020

1021
        model_config = self.model_config
1022
        runner_type = model_config.runner_type
1023
        if runner_type != "pooling":
1024
1025
1026
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
1027
1028
                "pooling model."
            )
1029

1030
1031
1032
1033
1034
1035
1036
1037
        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
1038
1039
                    "offline inference example for more details."
                )
1040
1041
1042
1043
1044
1045

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
1046

1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
        if io_processor_prompt:
            assert self.io_processor is not None
            if is_list_of(pooling_params, PoolingParams):
                validated_pooling_params: list[PoolingParams] = []
                for param in as_iter(pooling_params):
                    validated_pooling_params.append(
                        self.io_processor.validate_or_generate_params(param)
                    )
                pooling_params = validated_pooling_params
            else:
                assert not isinstance(pooling_params, Sequence)
                pooling_params = self.io_processor.validate_or_generate_params(
                    pooling_params
                )
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

1075
        self._validate_and_add_requests(
1076
            prompts=prompts,
1077
            params=pooling_params,
1078
            use_tqdm=use_tqdm,
1079
            lora_request=lora_request,
1080
            tokenization_kwargs=tokenization_kwargs,
1081
1082
        )

1083
        outputs = self._run_engine(use_tqdm=use_tqdm)
1084
1085

        model_outputs = self.engine_class.validate_outputs(
1086
1087
            outputs, PoolingRequestOutput
        )
1088
1089
1090
1091
1092

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
1093
1094
                model_output=model_outputs
            )
1095
1096

            return [
1097
1098
1099
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
1100
1101
1102
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
1103
1104
1105
                    prompt_token_ids=[],
                    finished=True,
                )
1106
1107
1108
            ]
        else:
            return model_outputs
1109

1110
1111
    def embed(
        self,
1112
        prompts: PromptType | Sequence[PromptType],
1113
        *,
1114
1115
1116
1117
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
1118
        tokenization_kwargs: dict[str, Any] | None = None,
1119
    ) -> list[EmbeddingRequestOutput]:
1120
1121
1122
1123
1124
1125
1126
1127
1128
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
1129
                for batch inference. See [PromptType][vllm.inputs.PromptType]
1130
                for more details about the format of each prompt.
1131
1132
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
1133
1134
1135
1136
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
1137
1138
1139
            lora_request: LoRA request to use for generation, if any.

        Returns:
1140
            A list of `EmbeddingRequestOutput` objects containing the
1141
1142
            embedding vectors in the same order as the input prompts.
        """
1143
        if "embed" not in self.supported_tasks:
1144
1145
            raise ValueError(
                "Embedding API is not supported by this model. "
1146
1147
                "Try converting the model using `--convert embed`."
            )
1148

1149
1150
1151
1152
1153
1154
1155
        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
1156
            tokenization_kwargs=tokenization_kwargs,
1157
        )
1158
1159
1160
1161
1162

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Tokenization overrides forwarded to `encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            class logits for each prompt, in the same order as the input
            prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1209
1210
    def reward(
        self,
1211
        prompts: PromptType | Sequence[PromptType],
1212
1213
        /,
        *,
1214
1215
1216
1217
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
1218
        tokenization_kwargs: dict[str, Any] | None = None,
1219
1220
1221
1222
1223
1224
1225
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
1226
                for more details about the format of each prompt.
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
1245
            pooling_task="token_classify",
1246
            tokenization_kwargs=tokenization_kwargs,
1247
1248
        )

1249
1250
    def _embedding_score(
        self,
1251
        tokenizer: TokenizerLike,
1252
1253
1254
1255
1256
1257
        text_1: list[str | TextPrompt | TokensPrompt],
        text_2: list[str | TextPrompt | TokensPrompt],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
1258
        tokenization_kwargs: dict[str, Any] | None = None,
1259
1260
    ) -> list[ScoringRequestOutput]:
        encoded_output: list[PoolingRequestOutput] = self.encode(
1261
            text_1 + text_2,
1262
            truncate_prompt_tokens=truncate_prompt_tokens,
1263
1264
            use_tqdm=use_tqdm,
            lora_request=lora_request,
1265
            pooling_params=pooling_params,
1266
            pooling_task="embed",
1267
            tokenization_kwargs=tokenization_kwargs,
1268
        )
1269

1270
1271
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
1272
1273
1274
1275

        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

1276
1277
1278
        scores = _cosine_similarity(
            tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2
        )
1279

1280
        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
1281
1282
1283
1284
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: TokenizerLike,
        data_1: list[str] | list[ScoreContentPartParam],
        data_2: list[str] | list[ScoreContentPartParam],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        score_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair with a cross-encoder pass.

        Each pair is rendered into a single prompt by `get_score_prompt`. The
        prompts are then run through the engine as pooling requests for the
        "score" task. A single-element `data_1` is repeated to pair with every
        element of `data_2`.

        When a rendered prompt carries `token_type_ids`, they are removed from
        the prompt. They are then attached, compressed, to a per-request clone
        of the pooling params.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        base_params = (
            pooling_params
            if pooling_params is not None
            else PoolingParams(task="score")
        )
        base_params.verify("score", model_config)

        # Work on a private copy so the caller's dict is never mutated by the
        # truncation validation below.
        tok_kwargs = dict(tokenization_kwargs or {})
        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tok_kwargs
        )

        request_prompts: list[PromptType] = []
        request_params: list[PoolingParams] = []
        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tok_kwargs,
                score_template=score_template,
            )

            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if not token_type_ids:
                request_params.append(base_params)
            else:
                per_request = base_params.clone()
                per_request.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
                request_params.append(per_request)

            request_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=request_prompts,
            params=request_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(
            raw_outputs, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(out) for out in validated]

1352
1353
    def score(
        self,
        data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional limit on the number of prompt
                tokens; forwarded unchanged to the underlying scoring helper.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            chat_template: The chat template to use for the scoring. If None, we
                use the model's default chat template. Only supported for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                the `embed` nor the `classify` task, is a cross-encoder whose
                `num_labels != 1`, or if `chat_template` is passed for a
                non-cross-encoder model. For text-only models, also raised when
                a `ScoreMultiModalParam`, a prompt with `multi_modal_data`, or a
                prompt that cannot be turned into a string is passed.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if not model_config.is_cross_encoder and chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(
                data: SingletonPrompt
                | Sequence[SingletonPrompt]
                | ScoreMultiModalParam,
            ):
                # A dict carrying "content" is treated as ScoreMultiModalParam,
                # which text-only models cannot consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                # Normalize a text-only prompt (str, TextPrompt, TokensPrompt)
                # to a plain string.
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
                    elif "prompt_token_ids" in prompt:
                        # The tokenizer for models such as
                        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support
                        # passing lists of tokens to the `text` and `text_pair`
                        # kwargs, so token prompts are decoded back to text.
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`, which is stripped when
                # Python runs with -O and would let non-strings slip through.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Unsupported prompt for scoring: expected a string, "
                        "TextPrompt or TokensPrompt, got "
                        f"{type(prompt).__name__}"
                    )
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap ScoreMultiModalParam to its content list; wrap bare strings.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
                score_template=chat_template,
            )
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
1510

1511
1512
1513
1514
1515
1516
    def start_profile(self) -> None:
        """Forward a start-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1517
1518
1519
1520
1521
1522
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Ask the engine to reset its prefix cache.

        Args:
            reset_running_requests: Passed positionally to the engine's
                `reset_prefix_cache`.
            reset_connector: Passed positionally to the engine's
                `reset_prefix_cache`.

        Returns:
            Whatever the engine returns (presumably whether the reset
            succeeded; confirm against the engine implementation).
        """
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return outcome
1523

1524
1525
1526
1527
1528
1529
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it should not process any requests meanwhile.

        The caller is responsible for making sure no requests are in flight
        from the moment this is called until `wake_up` is called.

        Args:
            level: The sleep level.
                - Level 1 offloads the model weights to CPU memory and discards
                  the kv cache (its content is forgotten). Use it when the
                  engine will be woken up to run the same model again; make
                  sure there is enough CPU memory to hold the model weights.
                - Level 2 discards both the model weights and the kv cache
                  (the content of both is forgotten). Use it when the engine
                  will be woken up to run a different or updated model, where
                  the previous weights are not needed; it reduces CPU memory
                  pressure.
        """
        # Both levels drop the kv cache, so clear the prefix cache first.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1546
    def wake_up(self, tags: list[str] | None = None):
1547
        """
1548
1549
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1550

1551
        Args:
1552
1553
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1554
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1555
                wake_up should be called with all tags (or None) before the
1556
1557
1558
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1559

1560
1561
1562
1563
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as returned by the engine.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1572
1573
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Validate a batch of prompts and add one engine request per prompt.

        Args:
            prompts: A single prompt (str or dict, wrapped into a list) or a
                sequence of prompts.
            params: One params object shared by every prompt, or a sequence
                with one entry per prompt. Any `SamplingParams` passed in is
                mutated in place to `RequestOutputKind.FINAL_ONLY`.
            use_tqdm: If truthy, show an "Adding requests" progress bar; a
                callable is used as the progress-bar factory.
            lora_request: One LoRA request shared by every prompt, a sequence
                with one entry per prompt, or None.
            priority: Optional per-prompt priorities; defaults to 0 each.
            tokenization_kwargs: Forwarded to every added request.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` is a sequence
                whose length differs from the number of prompts.

        If adding any request fails or is interrupted, the requests added so
        far are aborted before the exception propagates.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )
        if priority is not None and len(priority) != num_requests:
            raise ValueError(
                "The lengths of prompts "
                f"({num_requests}) and priority ({len(priority)}) "
                "must be the same."
            )

        for sp in params if params_is_seq else (params,):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(it):
                if isinstance(prompt, dict):
                    self._validate_mm_data_and_uuids(
                        prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
                    )
                request_id = self._add_request(
                    prompt,
                    params[i] if params_is_seq else params,
                    lora_request=lora_request[i] if lora_is_seq else lora_request,
                    priority=priority[i] if priority else 0,
                    tokenization_kwargs=tokenization_kwargs,
                )
                added_request_ids.append(request_id)
        except BaseException:
            # BaseException (not just Exception) so that a KeyboardInterrupt
            # midway through the loop also removes the partial batch; otherwise
            # the next `_run_engine` call, which steps until the engine has no
            # unfinished requests, would process these stale requests.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise
1636

1637
    def _validate_mm_data_and_uuids(
1638
        self,
1639
1640
        multi_modal_data: Any | None,  # MultiModalDataDict
        multi_modal_uuids: Any | None,  # MultiModalUUIDDict
1641
1642
1643
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1644
        then its corresponding UUID must be set.
1645
1646
1647
1648
1649
1650
1651
1652
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
1653
1654
1655
1656
1657
1658
1659
1660
                        if (
                            multi_modal_uuids is None
                            or modality not in multi_modal_uuids
                            or multi_modal_uuids[  # noqa: E501
                                modality
                            ]
                            is None
                        ):
1661
1662
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
1663
1664
                                f"but UUID is not provided"
                            )
1665
                        else:
1666
1667
1668
1669
                            if (
                                len(multi_modal_uuids[modality]) <= i
                                or multi_modal_uuids[modality][i] is None
                            ):
1670
1671
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
1672
1673
                                    f"but UUID is not provided"
                                )
1674
            else:
1675
1676
1677
1678
1679
1680
1681
1682
1683
                if data is None and (
                    multi_modal_uuids is None
                    or modality not in multi_modal_uuids
                    or multi_modal_uuids[modality] is None
                ):
                    raise ValueError(
                        f"Multi-modal data for {modality} is None"
                        f" but UUID is not provided"
                    )
1684

1685
1686
1687
1688
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        priority: int,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for LLMEngine.

        Returns:
            A pair of the processed engine request and the (copied)
            tokenization kwargs that were used to build it.
        """
        # Work on a private copy so the caller's dict is never modified by the
        # truncation validation or input processing below.
        kwargs = (tokenization_kwargs or {}).copy()

        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            kwargs,
        )

        request = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=kwargs,
            priority=priority,
        )
        return request, kwargs
1714

1715
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """Process a single prompt and submit it to the engine.

        A fresh ID is drawn from `self.request_counter` for the request.

        Returns:
            The `request_id` carried by the processed engine request.
        """
        prompt_text, _, _ = get_prompt_components(prompt)
        new_id = str(next(self.request_counter))

        engine_request, processed_kwargs = self._process_inputs(
            new_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )

        self.llm_engine.add_request(
            new_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=processed_kwargs,
            priority=priority,
            prompt_text=prompt_text,
        )
        return engine_request.request_id
1745

1746
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until it has no unfinished requests.

        Args:
            use_tqdm: If `True`, shows a tqdm progress bar whose postfix reports
                estimated input/output token throughput (computed from
                `RequestOutput` results only). If a callable, it is used to
                create the progress bar. If `False`, no progress bar is created.

        Returns:
            Every finished output collected while stepping, sorted by integer
            request ID (IDs are drawn from `self.request_counter` when the
            requests are added).
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        # Read `elapsed` once for both rates. It can be 0.0
                        # when a request finishes within the timer's
                        # resolution; keep the previous postfix then instead
                        # of raising ZeroDivisionError.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s"
                            )
                        pbar.update(n)
                    else:
                        pbar.update(1)
                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if stepping the engine raises.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were added before them.
        return sorted(outputs, key=lambda x: int(x.request_id))
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809

    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr