llm.py 74.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Callable, Sequence
7
from typing import TYPE_CHECKING, Any, TypeAlias, cast
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
from vllm.engine.arg_utils import EngineArgs
38
39
40
41
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
42
from vllm.entrypoints.pooling.score.utils import (
43
    ScoreData,
44
45
46
47
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
    get_score_prompt,
48
    validate_score_input,
49
)
50
from vllm.entrypoints.utils import log_non_default_args
51
52
from vllm.inputs import (
    DataPrompt,
53
54
    EmbedsPrompt,
    ExplicitEncoderDecoderPrompt,
55
56
57
58
59
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
60
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
61
from vllm.logger import init_logger
62
from vllm.lora.request import LoRARequest
63
from vllm.model_executor.layers.quantization import QuantizationMethods
64
65
66
67
68
69
70
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
71
from vllm.platforms import current_platform
72
from vllm.pooling_params import PoolingParams
73
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
74
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
75
from vllm.tasks import PoolingTask
76
77
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
78
from vllm.usage.usage_lib import UsageContext
79
from vllm.utils.collection_utils import as_iter, is_list_of
80
from vllm.utils.counter import Counter
81
from vllm.v1.engine.llm_engine import LLMEngine
82
from vllm.v1.sample.logits_processor import LogitsProcessor
83

84
85
86
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

logger = init_logger(__name__)

# Return-type variable for worker-side callables (e.g. `collective_rpc`,
# `apply_model`); `default=Any` lets unparameterized uses resolve to `Any`.
_R = TypeVar("_R", default=Any)

# Singleton prompt forms (text, token IDs, or embeddings) and the explicit
# encoder/decoder pairing built from them.
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16`
            instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        kv_cache_memory_bytes: Size of KV cache per GPU in bytes. By default,
            this is set to None and vLLM can automatically infer the KV cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the KV cache memory size. kv_cache_memory_bytes
            allows finer-grained control of how much memory gets used than
            gpu_memory_utilization does. Note that when kv_cache_memory_bytes
            is not None, gpu_memory_utilization is ignored.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        enable_return_routed_experts: Whether to return routed experts.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        pooler_config: Initialize non-default pooling config for the pooling
            model, e.g. `PoolerConfig(seq_pooling_type="MEAN", normalize=False)`.
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the mode of compilation optimization. If it is a dictionary, it can
            specify the full compilation configuration.
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the user-facing arguments into the forms `EngineArgs`
        expects: dict/int/None config values become config instances, a
        class-valued ``worker_cls`` is cloudpickled, and a dict
        ``kv_transfer_config`` becomes a ``KVTransferConfig``. It then builds
        the underlying `LLMEngine` and caches a few engine handles on ``self``.

        Raises:
            ValueError: If ``kv_transfer_config`` is a dict that fails
                validation, or if ``data_parallel_size > 1`` is requested
                without the ``external_launcher`` backend on a non-TPU
                platform.
        """
        # Stats logging is disabled unless the caller explicitly sets it.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if isinstance(kwargs.get("kv_transfer_config"), dict):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Surface a ValueError with context instead of the raw
                # pydantic ValidationError.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            ``None`` yields a default-constructed ``cls``; a dict is filtered
            through ``is_init_field`` before being passed to ``cls``; any
            other value is returned unchanged.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        # An integer compilation_config is shorthand for the compilation mode.
        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data parallel usage (it may hang); only the
        # external launcher backend and TPU platforms are exempt.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor
        self.io_processor = self.llm_engine.io_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer held by the underlying engine."""
        return self.llm_engine.get_tokenizer()

    def reset_mm_cache(self) -> None:
        """Clear the input processor's multi-modal cache, then the engine's."""
        self.input_processor.clear_mm_cache()
        self.llm_engine.reset_mm_cache()

    def get_default_sampling_params(self) -> SamplingParams:
        """Return a `SamplingParams` built from the model's default overrides.

        The overrides from ``model_config.get_diff_sampling_param()`` are
        fetched on the first call and cached in
        ``self.default_sampling_params``. Each call then builds a new
        `SamplingParams` from the cached overrides. If the overrides are
        empty, a plain ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        if self.default_sampling_params:
            return SamplingParams.from_optional(**self.default_sampling_params)
        return SamplingParams()

    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not ``"generate"``.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            # May substitute modality-specific default LoRAs per prompt.
            lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ):
        """Resolve the LoRA request(s) to use for ``prompts``.

        When the engine's LoRA config defines ``default_mm_loras`` and the
        model is multimodal, each prompt is paired with its (optional) LoRA
        request and passed to ``_resolve_single_prompt_mm_lora``, which may
        substitute a modality-specific default LoRA. Otherwise
        ``lora_request`` is returned unchanged.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: ``None``, a single request applied to every prompt,
                or a sequence with one entry per prompt.

        Returns:
            ``lora_request`` unchanged, or a list with one entry per prompt.

        Raises:
            ValueError: If ``lora_request`` is a sequence whose length differs
                from the number of prompts.
        """
        # Read the LoRA config off the engine's vllm config.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        # A bare string (or a dict-style prompt) is a single prompt.
        if not isinstance(prompts, Sequence) or isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently drop unmatched entries and hide a
            # user error, so enforce a one-to-one pairing up front.
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ):
        """Pick the LoRA request for one prompt, honoring modality defaults.

        Args:
            prompt: The prompt. Only dict prompts with non-empty
                ``multi_modal_data`` are considered for a default LoRA.
            lora_request: The explicitly requested LoRA for this prompt, if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            ``lora_request``, unless it is ``None`` and the prompt uses exactly
            one modality with a registered default LoRA. In that case a new
            ``LoRARequest`` for that modality is returned.
        """
        if (
            not default_mm_loras
            or not isinstance(prompt, dict)
            or not (mm_data := prompt.get("multi_modal_data") or {})
        ):
            return lora_request

        intersection = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())

        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                intersection,
            )
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins. Warn when its ID
        # differs from the modality default's so the override is visible.
        if lora_request is not None:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        # Thin pass-through; the engine owns worker dispatch.
        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        return self.llm_engine.apply_model(func)

    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: ``None`` or a single request (applied to every
                prompt), or a sequence with one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A list with exactly one (possibly ``None``) entry per prompt.

        Raises:
            ValueError: If a sequence is given whose length differs from
                ``len(prompts)``.
            TypeError: If ``lora_request`` is of any other type.
        """
        # Strings are Sequences too, but never a valid per-prompt list.
        if isinstance(lora_request, Sequence) and not isinstance(lora_request, str):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            # Per-prompt requests are used in order, paired by index.
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

    def beam_search(
        self,
590
        prompts: list[TokensPrompt | TextPrompt],
591
        params: BeamSearchParams,
592
        lora_request: list[LoRARequest] | LoRARequest | None = None,
593
        use_tqdm: bool = False,
594
        concurrency_limit: int | None = None,
595
    ) -> list[BeamSearchOutput]:
596
597
598
599
600
601
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
602
            params: The beam search parameters.
603
            lora_request: LoRA request to use for generation, if any.
604
            use_tqdm: Whether to use tqdm to display the progress bar.
605
606
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
607
        """
608
609
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
610
611
612
613
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
614
615
        length_penalty = params.length_penalty

616
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
617

618
619
620
621
622
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
623

624
625
626
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
627
628
                "Disabling progress bar."
            )
629
630
631
632
633
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

634
635
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
636
637
638
639
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
640
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
641
            return TokensPrompt(**token_prompt_kwargs)
642

643
644
645
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
646
        beam_search_params = SamplingParams(
647
648
649
650
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
651
        )
652
        instances: list[BeamSearchInstance] = []
653

654
        for lora_req, prompt in zip(lora_requests, prompts):
655
656
657
658
659
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
660
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
661

662
663
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
664
665
666
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
667

668
            instances.append(
669
670
671
672
673
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
674
675
                ),
            )
676

677
        for prompt_start in range(0, len(prompts), concurrency_limit):
678
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
679
680
681

            token_iter = range(max_tokens)
            if use_tqdm:
682
683
684
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
685
686
687
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
688
689
                    "reflect instance-level progress."
                )
690
691
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
692
693
                    sum((instance.beams for instance in instances_batch), [])
                )
694
695
                pos = [0] + list(
                    itertools.accumulate(
696
697
698
                        len(instance.beams) for instance in instances_batch
                    )
                )
699
                instance_start_and_end: list[tuple[int, int]] = list(
700
701
                    zip(pos[:-1], pos[1:])
                )
702
703
704
705
706
707

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
708
709
710
711
712
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
713
714
715

                # only runs for one step
                # we don't need to use tqdm here
716
717
718
719
720
721
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
722

723
724
725
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
726
727
728
729
730
731
732
733
734
735
736
737
738
739
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
740
                                    logprobs=current_beam.logprobs + [logprobs],
741
                                    lora_request=current_beam.lora_request,
742
743
744
745
746
747
748
749
750
751
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
752
753
754
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
755
756
757
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
758
                    instance.beams = sorted_beams[:beam_width]
759
760
761
762

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
763
764
765
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
766
767
768
769
770
771
772
773
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

774
775
776
777
778
779
780
781
782
783
784
785
786
787
    def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """Build the tokenization parameters used for completion-style prompts.

        Args:
            tokenization_kwargs: Optional overrides layered on top of the
                defaults via `TokenizeParams.with_kwargs`.

        Returns:
            The `TokenizeParams` to use for non-chat prompts.
        """
        cfg = self.model_config
        lower_case = (cfg.encoder_config or {}).get("do_lower_case", False)
        # For Whisper, special tokens should be provided by the user based
        # on the task and language of their request. Also needed to avoid
        # appending an EOS token to the prompt which disrupts generation.
        # Hence special tokens are only added for non encoder-decoder models.
        auto_special_tokens = not cfg.is_encoder_decoder

        defaults = TokenizeParams(
            max_total_tokens=cfg.max_model_len,
            do_lower_case=lower_case,
            add_special_tokens=auto_special_tokens,
        )
        return defaults.with_kwargs(tokenization_kwargs)

    def _normalize_prompts(
        self,
        prompts: PromptType | Sequence[PromptType],
    ) -> list[EnginePrompt | EngineEncDecPrompt]:
        """Normalize the `prompts` argument of the public APIs into a list.

        Args:
            prompts: A single prompt or a sequence of prompts. A bare `str`
                is treated as one text prompt.

        Returns:
            The input sequence itself (not a copy) when `prompts` is a
            non-string sequence; otherwise a one-element list holding the
            single prompt (a `str` is first wrapped in a `TextPrompt`).
        """
        # `str` is itself a `Sequence`, so it has to be wrapped first or it
        # would be returned unchanged by the sequence check below.
        single_or_many = (
            TextPrompt(prompt=prompts) if isinstance(prompts, str) else prompts
        )
        if isinstance(single_or_many, Sequence):
            return single_or_many  # type: ignore[return-value]
        return [single_or_many]  # type: ignore[list-item]

    def _preprocess_cmpl_singleton(
        self,
        prompt: SingletonPrompt,
        tok_params: TokenizeParams,
        *,
        tokenize: bool,
    ) -> EnginePrompt:
        """Render a single non-chat prompt and optionally tokenize it.

        Args:
            prompt: One prompt. Non-dict prompts are converted with
                `renderer.render_completion`; dict prompts are used directly.
            tok_params: Tokenization parameters forwarded to the renderer.
            tokenize: When False, the rendered prompt is returned without
                passing through `renderer.tokenize_prompt`.

        Returns:
            The rendered (and, if requested, tokenized) engine prompt.
        """
        renderer = self.llm_engine.renderer
        rendered = (
            prompt
            if isinstance(prompt, dict)
            else renderer.render_completion(prompt)
        )
        if not tokenize:
            return rendered
        return renderer.tokenize_prompt(rendered, tok_params)

    def _preprocess_cmpl_enc_dec(
        self,
        prompt: ExplicitEncoderDecoderPrompt,
        tok_params: TokenizeParams,
    ) -> EngineEncDecPrompt:
        """Preprocess both halves of an explicit encoder/decoder prompt.

        The encoder prompt is always processed. The decoder prompt is
        processed only when present and is otherwise kept as `None`.

        Args:
            prompt: The explicit encoder/decoder prompt.
            tok_params: Tokenization parameters applied to both halves.

        Returns:
            An `EngineEncDecPrompt` holding the processed halves.
        """
        # TODO: Move multi-modal processor into tokenization
        # Multi-modal models are not tokenized at this stage.
        should_tokenize = not self.model_config.is_multimodal_model

        def _process(singleton: SingletonPrompt) -> EnginePrompt:
            return self._preprocess_cmpl_singleton(
                singleton, tok_params, tokenize=should_tokenize
            )

        encoder_part = prompt["encoder_prompt"]
        decoder_part = prompt["decoder_prompt"]
        return EngineEncDecPrompt(
            encoder_prompt=_process(encoder_part),
            decoder_prompt=None if decoder_part is None else _process(decoder_part),
        )

    def _preprocess_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EnginePrompt | EngineEncDecPrompt]:
        """
        Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
        a format that can be passed to `_add_request`.

        Refer to [LLM.generate][] for a complete description of the arguments.

        Args:
            prompts: A single prompt or a sequence of prompts. Explicit
                encoder/decoder prompts have each half preprocessed
                separately.
            tokenization_kwargs: Optional overrides merged into the default
                completion tokenization parameters.

        Returns:
            One engine prompt per input prompt, in input order. No chat
            template is applied on this path. Prompts are tokenized here
            unless the model is multi-modal, in which case they are returned
            rendered but not yet tokenized.
        """
        tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
        # TODO: Move multi-modal processor into tokenization
        # Some MM models have non-default `add_special_tokens`, so
        # multi-modal prompts are not tokenized here.
        tokenize = not self.model_config.is_multimodal_model

        engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
        for prompt in self._normalize_prompts(prompts):
            if is_explicit_encoder_decoder_prompt(prompt):
                engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
            else:
                engine_prompts.append(
                    self._preprocess_cmpl_singleton(
                        prompt,
                        tok_params,
                        tokenize=tokenize,
                    )
                )

        return engine_prompts

    def _normalize_conversations(
        self,
        conversations: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
    ) -> list[list[ChatCompletionMessageParam]]:
        """Return `conversations` as a list of conversations.

        If `is_list_of(conversations, list)` holds, the input is already a
        batch and is returned unchanged. Otherwise it is treated as a single
        conversation and wrapped in a one-element list.
        """
        is_batch = is_list_of(conversations, list)
        if not is_batch:
            return [conversations]  # type: ignore[list-item]
        return conversations  # type: ignore[return-value]

    def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """Build the tokenization parameters used for rendered chat prompts.

        Special tokens are never added automatically on this path
        (`add_special_tokens=False`); presumably the chat template is
        responsible for them. `tokenization_kwargs` may override any default
        via `TokenizeParams.with_kwargs`.
        """
        cfg = self.model_config
        lower_case = (cfg.encoder_config or {}).get("do_lower_case", False)

        defaults = TokenizeParams(
            max_total_tokens=cfg.max_model_len,
            do_lower_case=lower_case,
            add_special_tokens=False,
        )
        return defaults.with_kwargs(tokenization_kwargs)

    def _preprocess_chat(
        self,
        conversations: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[EnginePrompt]:
        """
        Convert a list of conversations into prompts so that they can then
        be used as input for other LLM APIs.

        Refer to [LLM.chat][] for a complete description of the arguments.

        Each conversation is rendered through the engine's renderer with the
        chat template, has `mm_processor_kwargs` attached when given, and is
        then tokenized with the chat tokenization parameters.

        Returns:
            A list of `TokensPrompts` objects containing the tokenized prompt
            after chat template interpolation, and the raw multi-modal inputs.
        """
        renderer = self.llm_engine.renderer

        template_overrides = dict(
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            # Only Mistral tokenizers are asked to tokenize while rendering.
            tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
        )
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merge_kwargs(chat_template_kwargs, template_overrides),
        )
        tok_params = self._get_chat_tok_params(tokenization_kwargs)

        def _render_and_tokenize(conversation):
            _, rendered = renderer.render_messages(conversation, chat_params)
            if mm_processor_kwargs is not None:
                rendered["mm_processor_kwargs"] = mm_processor_kwargs
            return renderer.tokenize_prompt(rendered, tok_params)

        return [
            _render_and_tokenize(conversation)
            for conversation in self._normalize_conversations(conversations)
        ]
937
938
939

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to append
                a generation prompt (the marker that opens the assistant's
                reply) to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions (as dictionaries) that is
                passed to the chat template as its `tools` variable.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`. Used when
                tokenizing the rendered conversations and also forwarded to
                [generate][vllm.LLM.generate].
            mm_processor_kwargs: Overrides for `processor.__call__`, attached
                to every rendered conversation.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        prompts = self._preprocess_chat(
            messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

1024
1025
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict with
                a `"data"` key is handed to the installed IOProcessor plugin.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Deprecated. Pass `truncate_prompt_tokens`
                via `tokenization_kwargs` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. Required; must be one of
                the model's supported tasks.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For IOProcessor prompts, a single output that wraps the plugin's
            post-processed result.

        Raises:
            ValueError: If `pooling_task` is missing or unsupported, if the
                model is not a pooling model, or if an IOProcessor prompt is
                given but no IOProcessor plugin is installed.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        # Reject an unsupported task before any (possibly expensive)
        # IOProcessor parsing / pre-processing is done on the request.
        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        if truncate_prompt_tokens is not None:
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        io_processor_prompt = isinstance(prompts, dict) and "data" in prompts
        if io_processor_prompt:
            io_processor = self.io_processor
            if io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = io_processor.pre_process(prompt=validated_prompt)

            # Let the plugin validate (or generate) the pooling params. Any
            # sequence of params, as allowed by the signature, is validated
            # element-wise; a single params object (or None) is passed as-is.
            if isinstance(pooling_params, Sequence):
                pooling_params = [
                    io_processor.validate_or_generate_params(param)
                    for param in pooling_params
                ]
            else:
                pooling_params = io_processor.validate_or_generate_params(
                    pooling_params
                )

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if not io_processor_prompt:
            return model_outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs
        )

        return [
            PoolingRequestOutput[Any](
                request_id="",
                outputs=processed_outputs,
                num_cached_tokens=getattr(
                    processed_outputs, "num_cached_tokens", 0
                ),
                prompt_token_ids=[],
                finished=True,
            )
        ]
1177

1178
1179
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                as the `truncate_prompt_tokens` override.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"embed"` task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        # Fold the legacy keyword into the tokenization overrides instead of
        # passing it to `encode`, which would emit a deprecation warning.
        effective_tok_kwargs = (
            tokenization_kwargs
            if truncate_prompt_tokens is None
            else merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )
        )

        pooled = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
            tokenization_kwargs=effective_tok_kwargs,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"classify"` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1285
1286
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                as the `truncate_prompt_tokens` override.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        # Merge `truncate_prompt_tokens` into the tokenization overrides, as
        # `embed` does, rather than forwarding it to `encode`. `encode`
        # treats that keyword as deprecated: its warning names `LLM.encode()`
        # and, through `stacklevel=2`, points at this method instead of the
        # user's call site. The resulting tokenization kwargs are the same.
        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

1327
1328
    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        All items of `data_1` followed by all items of `data_2` are embedded
        in a single `encode` call. If `data_1` yields exactly one embedding,
        it is repeated to pair with every embedding from `data_2`; otherwise
        embeddings are paired positionally.

        Raises:
            NotImplementedError: If any item is not a plain string.
        """
        tokenizer = self.get_tokenizer()

        num_left = len(data_1)
        input_texts = [text for text in (*data_1, *data_2) if isinstance(text, str)]
        if len(input_texts) != num_left + len(data_2):
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )

        embeddings = self.encode(
            input_texts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        left_embeds = embeddings[:num_left]
        right_embeds = embeddings[num_left:]
        # A single left-hand item is broadcast against every right-hand item.
        if len(left_embeds) == 1:
            left_embeds = left_embeds * len(right_embeds)

        scores = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=left_embeds,
            embed_2=right_embeds,
        )

        validated = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in validated]

    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair jointly with a cross-encoder.

        A single-element `data_1` is replicated to match `data_2`. Each pair
        is rendered into one prompt by `get_score_prompt`. If that prompt
        carries `token_type_ids`, they are removed from the prompt and passed
        in compressed form through a per-request clone of the pooling params.
        All prompts are then submitted and the engine is run to completion.

        Raises:
            ValueError: If the tokenizer is a `MistralTokenizer`.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N: broadcast the lone query over every document.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        base_params = (
            PoolingParams(task="score") if pooling_params is None else pooling_params
        )
        base_params.verify("score", model_config)

        prompts: list[PromptType] = []
        request_params: list[PoolingParams] = []
        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if not token_type_ids:
                # Nothing pair-specific to forward; share the base params.
                request_params.append(base_params)
            else:
                # Clone so pair-specific extra_kwargs stay on this request only.
                pair_params = base_params.clone()
                pair_params.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
                request_params.append(pair_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=request_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(
            raw_outputs, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(item) for item in validated]

1433
1434
    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`. In the `1 -> N`
        case the single `data_1` input is replicated `N` times so that it
        pairs with each of the `data_2` inputs.

        Cross-encoder models score each pair jointly. Other pooling models
        embed both sides and score each pair by cosine similarity; that path
        accepts text only. The input pairs are turned into prompts that are
        batched automatically, considering the memory constraint. For the
        best performance, put all of your inputs into a single list and pass
        it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: A single prompt, a list of prompts, or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. A list must either have one element
                (`1 -> N`) or the same length as the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. Only
                valid for cross-encoder models. If None, we use the model's
                default chat template.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, supports
                neither the `embed` nor the `classify` task, is a
                cross-encoder whose `num_labels` is not 1, or is not a
                cross-encoder but `chat_template` was given.
        """
        model_config = self.model_config

        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if "embed" not in supported_tasks and "classify" not in supported_tasks:
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        is_cross_encoder = model_config.is_cross_encoder
        if is_cross_encoder:
            if getattr(model_config.hf_config, "num_labels", 0) != 1:
                raise ValueError("Score API is only enabled for num_labels == 1.")
        elif chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=model_config.is_multimodal_model,
            architecture=model_config.architecture,
        )

        encode_kwargs = self._get_cmpl_tok_params(
            tokenization_kwargs
        ).get_encode_kwargs()

        # Arguments shared by both scoring strategies.
        shared_kwargs: dict[str, Any] = dict(
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
        )
        if not is_cross_encoder:
            return self._embedding_score(score_data_1, score_data_2, **shared_kwargs)
        return self._cross_encoding_score(
            score_data_1,
            score_data_2,
            score_template=chat_template,
            **shared_kwargs,
        )
1550

1551
1552
1553
1554
1555
1556
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1557
1558
1559
1560
1561
1562
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Both flags are forwarded positionally to
        `llm_engine.reset_prefix_cache`, and its boolean result is returned
        unchanged.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(reset_running_requests, reset_connector)
1563

1564
1565
1566
1567
1568
1569
    def sleep(self, level: int = 1):
        """Put the engine to sleep; it must not process requests while asleep.

        The prefix cache is reset first. The caller must make sure no
        requests are being processed from this point until `wake_up` is
        called.

        Args:
            level: How much memory to release.
                Level 1 offloads the model weights, backing them up in CPU
                memory (make sure there is enough CPU memory to hold them),
                and discards the KV cache, whose contents are forgotten. It
                suits sleeping and waking up to run the same model again.
                Level 2 discards both the model weights and the KV cache,
                forgetting the contents of both. It suits sleeping and waking
                up to run a different or updated model, where the previous
                weights are not needed, and it reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1586
    def wake_up(self, tags: list[str] | None = None):
1587
        """
1588
1589
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1590

1591
        Args:
1592
1593
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1594
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1595
                wake_up should be called with all tags (or None) before the
1596
1597
1598
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1599

1600
1601
1602
1603
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects, as returned by
            `llm_engine.get_metrics()`, capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1612
1613
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: list[int] | None = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt to the engine.

        `params` and `lora_request` may each be a single value, which is
        reused for every prompt, or a sequence with one entry per prompt.
        `priority` is either `None` (every request gets 0) or a list with one
        entry per prompt.

        Any `SamplingParams` passed in have `output_kind` set to
        `RequestOutputKind.FINAL_ONLY` in place. If adding a request fails,
        the requests already added by this call are aborted before the
        exception propagates.

        Args:
            prompts: One prompt or a sequence of prompts.
            params: Sampling or pooling parameters, shared or per prompt.
            use_tqdm: If truthy, show an "Adding requests" progress bar. A
                callable is used as the tqdm factory.
            lora_request: LoRA request(s), shared or per prompt.
            tokenization_kwargs: Overrides forwarded to tokenization.
            priority: Optional per-prompt priorities.

        Raises:
            ValueError: If a per-request sequence does not have exactly one
                entry per prompt.
        """
        in_prompts = self._normalize_prompts(prompts)
        num_requests = len(in_prompts)

        if isinstance(params, Sequence):
            if len(params) != num_requests:
                # Report the prompt count (not the params object itself) and
                # name the argument that actually mismatched.
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )

            engine_params = params
        else:
            engine_params = [params] * num_requests

        if isinstance(lora_request, Sequence):
            if len(lora_request) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and lora_request ({len(lora_request)}) must be the same."
                )

            engine_lora_requests: Sequence[LoRARequest | None] = lora_request
        else:
            engine_lora_requests = [lora_request] * num_requests

        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )
        else:
            priority = [0] * num_requests

        if any(param.truncate_prompt_tokens is not None for param in engine_params):
            # TODO: Remove this after deprecating `param.truncate_prompt_tokens`
            # Then, move the code from the `else` block to the top and let
            # `self._preprocess_completion` handle prompt normalization
            engine_prompts = [
                engine_prompt
                for in_prompt, param in zip(in_prompts, engine_params)
                for engine_prompt in self._preprocess_completion(
                    [in_prompt],
                    tokenization_kwargs=merge_kwargs(
                        tokenization_kwargs,
                        dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
                    ),
                )
            ]
        else:
            engine_prompts = self._preprocess_completion(
                in_prompts,
                tokenization_kwargs=tokenization_kwargs,
            )

        for sp in engine_params:
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = engine_prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(it):
                request_id = self._add_request(
                    prompt,
                    engine_params[i],
                    lora_request=engine_lora_requests[i],
                    tokenization_kwargs=tokenization_kwargs,
                    priority=priority[i],
                )
                added_request_ids.append(request_id)
        except Exception:
            # Don't leave a partially-submitted batch running in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise
1707

1708
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: int = 0,
    ) -> str:
        """Turn one prompt into an engine request and submit it.

        A fresh request id is taken from `self.request_counter`. A legacy
        `params.truncate_prompt_tokens` value triggers a `DeprecationWarning`
        and is merged into `tokenization_kwargs`. The prompt is processed by
        `self.input_processor` and handed to `self.llm_engine.add_request`.

        Returns:
            The `request_id` of the processed engine request.
        """
        prompt_text, _, _ = get_prompt_components(prompt)
        request_id = str(next(self.request_counter))

        legacy_truncation = params.truncate_prompt_tokens
        if legacy_truncation is not None:
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{type(params).__name__}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=legacy_truncation),
            )

        encode_kwargs = self._get_cmpl_tok_params(
            tokenization_kwargs
        ).get_encode_kwargs()

        engine_request = self.input_processor.process_inputs(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
            priority=priority,
        )
        self.llm_engine.add_request(
            request_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
            priority=priority,
            prompt_text=prompt_text,
        )
        return engine_request.request_id
1756

1757
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until no unfinished requests remain.

        Only outputs whose `finished` flag is set are collected. When
        `use_tqdm` is truthy, a "Processed prompts" bar is shown. For
        `RequestOutput`s its postfix reports estimated input and output token
        throughput.

        Args:
            use_tqdm: `True` for a default tqdm bar, a callable to use as the
                tqdm factory, or `False` for no progress bar.

        Returns:
            The finished outputs sorted by `int(request_id)`. Ids are
            numeric strings, as allocated by `_add_request`.
        """
        # Initialize tqdm.
        if use_tqdm:
            # The bar total is the number of requests pending right now.
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            # The prompt is counted once per completion.
                            # The bar also advances by the number of
                            # completions, not by 1.
                            # NOTE(review): this assumes num_requests counts
                            # completions the same way; confirm.
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            # NOTE(review): divides by the bar's elapsed time,
                            # which could in principle be 0 on a very fast
                            # first completion; confirm this cannot happen.
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs
                            )
                            out_spd = total_out_toks / pbar.format_dict["elapsed"]
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s"
                            )
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        # Force a redraw once the bar reaches its total so the
                        # final state is displayed.
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were submitted before them.
        return sorted(outputs, key=lambda x: int(x.request_id))
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819

    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model.

        The text is fetched once via
        `collective_rpc("get_model_inspection")` and then memoized in
        `self._cached_repr`, so later calls issue no RPCs. If no worker
        returns a result, a short `LLM(model=...)` string is used instead.
        """
        if self._cached_repr is not None:
            return self._cached_repr

        # Every worker reports the same inspection in distributed settings,
        # so the first result is enough.
        worker_reprs = self.llm_engine.collective_rpc("get_model_inspection")
        self._cached_repr = (
            worker_reprs[0]
            if worker_reprs
            else f"LLM(model={self.model_config.model!r})"
        )
        return self._cached_repr