llm.py 71.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
import vllm.envs as envs
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
16
17
                              BeamSearchSequence,
                              create_sort_beams_key_function)
18
19
from vllm.config import (CompilationConfig, ModelDType,
                         StructuredOutputsConfig, TokenizerMode, is_init_field)
20
21
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
Joe Runde's avatar
Joe Runde committed
22
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
23
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
24
                                         ChatTemplateContentFormatOption,
25
26
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
27
28
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
29
30
# yapf conflicts with isort for this block
# yapf: disable
31
32
33
34
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
35
                                          compress_token_type_ids,
36
                                          get_score_prompt)
37
# yapf: enable
38
39
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
40
41
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                         TokensPrompt)
42
from vllm.logger import init_logger
43
from vllm.lora.request import LoRARequest
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
46
47
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
48
from vllm.plugins.io_processors import get_io_processor
49
from vllm.pooling_params import PoolingParams
50
51
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
52
from vllm.tasks import PoolingTask
53
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
54
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
55
from vllm.usage.usage_lib import UsageContext
56
from vllm.utils import Counter, Device, as_iter, is_list_of
57
from vllm.v1.sample.logits_processor import LogitsProcessor
58

59
60
61
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

62
63
logger = init_logger(__name__)

64
65
_R = TypeVar("_R", default=Any)

66
67

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
68
69
70
71
72
73
74
75
76
77
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
78
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
79
80
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
81
82
83
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
84
85
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
86
87
88
89
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
90
91
92
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
93
94
95
96
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
97
        quantization: The method used to quantize the model weights. Currently,
98
            we support "awq", "gptq", and "fp8" (experimental).
99
100
101
102
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
103
104
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
105
106
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
107
108
109
110
111
112
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
113
114
115
116
117
118
119
120
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
            default, this is None and vLLM infers the KV cache size from
            `gpu_memory_utilization`. Users who want to specify the KV cache
            size manually can set it; this gives finer-grained control over
            how much memory is used than `gpu_memory_utilization` does. When
            `kv_cache_memory_bytes` is not None, `gpu_memory_utilization` is
            ignored.
121
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
122
123
124
125
126
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
127
128
129
130
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
131
132
133
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
134
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
135
            When a sequence has context length larger than this, we fall back
136
137
138
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
139
140
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
141
142
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
143
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
146
147
148
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
149
150
151
152
153
154
155
156
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
157
158
159
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
160
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
161

162
163
    Note:
        This class is intended to be used for offline inference. For online
164
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
165
    """
166
167
168
169

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        structured_outputs_config: Optional[Union[dict[
            str, Any], StructuredOutputsConfig]] = None,
        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the user-facing arguments (documented on the class) into
        an `EngineArgs`, creates the engine with `LLMEngine.from_engine_args`,
        records the tasks the loaded model supports, and loads the optional
        input/output processor plugin.

        Raises:
            ValueError: If `kv_transfer_config` is passed as a dict that
                cannot be converted to a `KVTransferConfig`.
        """
        # Stats logging is disabled unless the caller set it explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # If the worker_cls is a class rather than a qualified string name,
        # serialize it with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_kv_transfer_config = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer_config, dict):
            from vllm.config.kv_transfer import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer_config)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer_config, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        # compilation_config may be an optimization level (int), a dict of
        # CompilationConfig fields, or a ready-made CompilationConfig.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            compilation_config_instance = self._config_from_dict(
                CompilationConfig, compilation_config)
        else:
            compilation_config_instance = compilation_config

        # structured_outputs_config may be a dict of StructuredOutputsConfig
        # fields or a ready-made StructuredOutputsConfig.
        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            structured_outputs_instance = self._config_from_dict(
                StructuredOutputsConfig, structured_outputs_config)
        else:
            structured_outputs_instance = structured_outputs_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily filled by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        if envs.VLLM_USE_V1:
            supported_tasks = self.llm_engine \
                .get_supported_tasks()  # type: ignore
        else:
            supported_tasks = self.llm_engine.model_config.supported_tasks

        logger.info("Supported_tasks: %s", supported_tasks)

        self.supported_tasks = supported_tasks

        # Load the Input/Output processor plugin if any
        io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(self.llm_engine.vllm_config,
                                             io_processor_plugin)

    @staticmethod
    def _config_from_dict(config_cls: type, values: dict[str, Any]) -> Any:
        """Instantiate `config_cls` from `values`.

        Only keys for which `is_init_field(config_cls, key)` is true are
        passed to the constructor; all other keys are silently dropped.
        """
        return config_cls(**{
            k: v
            for k, v in values.items() if is_init_field(config_cls, k)
        })

323
324
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
325
326

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the underlying engine.

        A tokenizer that does not already look cached is wrapped with
        `get_cached_tokenizer` before being stored.
        """
        # Cached tokenizer classes are created dynamically, so the class name
        # is the only thing to compare against. A user-defined tokenizer
        # whose class name starts with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer = (tokenizer if looks_cached else
                                     get_cached_tokenizer(tokenizer))
334

335
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default `SamplingParams` for this model.

        The overrides from `model_config.get_diff_sampling_param()` are
        fetched on first use and cached in `self.default_sampling_params`.
        When there are no overrides, a plain `SamplingParams()` is returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

343
344
345
346
347
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a generative model, i.e. its
                runner type is not `"generate"`.

        Note:
            Pre-tokenized input is also passed through `prompts` (see
            [PromptType][vllm.inputs.PromptType]); this method has no
            separate `prompt_token_ids` or `inputs` parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
411

412
    def _get_modality_specific_lora_reqs(
            self, prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default modality-specific LoRAs to the given prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: The caller's LoRA request(s). A single request (or
                None) is applied to every prompt; a list is paired with the
                prompts one by one.

        Returns:
            `lora_request` unchanged when there is no LoRA config, no
            `default_mm_loras`, or the model is not multimodal. Otherwise a
            list with one (possibly None) LoRA request per prompt, resolved
            by `_resolve_single_prompt_mm_lora`.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # A bare string is itself a Sequence (of characters), so it has to be
        # wrapped explicitly; otherwise one LoRA entry would be produced per
        # character instead of one for the single prompt.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        optional_loras = (lora_request if isinstance(lora_request, Sequence)
                          else [lora_request] * len(prompts))

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

441
    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request to use for a single prompt.

        If the prompt is a dict whose `multi_modal_data` covers exactly one
        modality with a registered default LoRA, a `LoRARequest` for that
        LoRA is built. An explicitly provided `lora_request` always takes
        precedence over it.

        Args:
            prompt: The prompt to inspect.
            lora_request: The LoRA request provided for this prompt, if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            The LoRA request to use for this prompt, possibly None.
        """
        if not default_mm_loras or not isinstance(prompt, dict):
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)

        # A missing key, an explicit None, and an empty mapping all mean
        # "no multimodal data", so no modality LoRA can apply.
        mm_data = prompt.get("multi_modal_data")
        if not mm_data:
            return lora_request

        intersection = set(mm_data.keys()).intersection(
            default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided lora_request always wins; warn only when its
        # ID differs from the one the modality LoRA would have used.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

487
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Run a method on every worker and collect the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to every worker for execution. A callable
                must accept an extra `self` argument (the worker object)
                besides whatever is supplied through `args` and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait; exceeding
                it raises [`TimeoutError`][]. `None` waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, gathered into a list.

        Note:
            Prefer using this API for control messages only, and set up a
            separate data-plane channel for bulk data transfer.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
516
517

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model inside every worker.

        Returns:
            The per-worker results, as returned by the model executor's
            `apply_model`.
        """
        return self.llm_engine.model_executor.apply_model(func)
524

525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None, a single request applied to every prompt, or
                a sequence holding exactly one request per prompt.
            prompts: The beam-search prompts.

        Returns:
            A list with one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If a sequence of requests does not have the same
                length as `prompts`.
            TypeError: If `lora_request` is of any other type.
        """
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            # One request per prompt; the length was validated above.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

541
542
    def beam_search(
        self,
543
        prompts: list[Union[TokensPrompt, TextPrompt]],
544
        params: BeamSearchParams,
545
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
546
        use_tqdm: bool = False,
547
        concurrency_limit: Optional[int] = None,
548
    ) -> list[BeamSearchOutput]:
549
550
551
552
553
554
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
555
            params: The beam search parameters.
556
            lora_request: LoRA request to use for generation, if any.
557
            use_tqdm: Whether to use tqdm to display the progress bar.
558
559
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
560
        """
561
562
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
563
564
565
566
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
567
568
        length_penalty = params.length_penalty

569
570
571
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

572
573
574
575
576
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
577

578
579
580
581
582
583
584
585
586
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

587
588
589
590
591
592
593
594
595
596
597
598
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
599

600
601
602
603
604
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
605
                                            temperature=temperature)
606
        instances: list[BeamSearchInstance] = []
607

608
        for lora_req, prompt in zip(lora_requests, prompts):
609
610
611
612
613
614
615
616
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

617
618
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
619
620
621
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
622

623
            instances.append(
624
625
626
627
628
629
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
630

631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
704
705
706
707
708

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
709
                                      key=sort_beams_key,
710
711
712
713
714
715
716
717
718
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

719
    def preprocess_chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Render and tokenize one or more chat conversations.

        The returned prompts can be fed to the other LLM methods (for
        example `generate`). See `chat` for a complete description of the
        arguments.

        Returns:
            One `TokensPrompt` per conversation, holding the token ids of
            the prompt after chat template interpolation, plus the parsed
            multi-modal data / UUIDs and `mm_processor_kwargs` when present.
        """
        # A list of lists is a batch of conversations; anything else is a
        # single conversation that gets wrapped into a batch of one.
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Entries from `chat_template_kwargs` override the named arguments.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        prompts: list[TokensPrompt] = []
        for conversation_messages in conversations:
            # NOTE: chat message parsing has no support for
            # mm_processor_kwargs, so they are attached to the prompt below.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_messages,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if not isinstance(tokenizer, MistralTokenizer):
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already contains any special tokens, so
                # the tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)
            else:
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_messages,
                    **template_kwargs,
                )

            prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return prompts

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        The conversation(s) are rendered and tokenized by `preprocess_chat`
        and the resulting prompts are passed to
        [generate][vllm.LLM.generate].

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

                The default "auto" is resolved by
                `resolve_chat_template_content_format`.
            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions forwarded to the chat template; they
                are also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. They take precedence over the named template
                arguments above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Everything that shapes the prompt goes to preprocess_chat; only
        # the generation options are left for generate().
        template_options: dict[str, Any] = dict(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        tokenized_prompts = self.preprocess_chat(messages=messages,
                                                 **template_options)

        return self.generate(
            tokenized_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

903
904
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is treated as an IOProcessor plugin
                request: it is parsed and pre-processed by the installed
                plugin, and the model outputs are post-processed by it.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Kept for backwards compatibility. If not
                None, it is written onto every `PoolingParams` in
                `pooling_params` (mutating the objects that were passed in).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. If `None` is
                passed explicitly, `"embed"` is used when the model supports
                it and `"encode"` otherwise. A one-time warning is logged
                unless `"encode"` is the model's only supported task.
            tokenization_kwargs: Intended to override the tokenization
                kwargs set in `pooling_params`.
                NOTE(review): this argument is currently not read anywhere
                in this method; confirm whether it should be forwarded to
                `_validate_and_add_requests`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For an IOProcessor plugin request, a single-element list that
            wraps the plugin's post-processed output.

        Raises:
            ValueError: If the model is not run as a pooling model, if
                `pooling_task` is not supported by the model, or if a plugin
                request is given but no IOProcessor plugin is installed.
        """
        if pooling_task is None:
            # Normalise to a tuple before comparing: `supported_tasks` may be
            # stored as a tuple, and `tuple == list` is always False, which
            # would wrongly emit the warning for encode-only models.
            if tuple(self.supported_tasks) == ("encode", ):
                # "encode" is the only choice, so there is nothing to warn
                # the user about.
                pooling_task = "encode"
            else:
                if "embed" in self.supported_tasks:
                    pooling_task = "embed"
                else:
                    pooling_task = "encode"

                logger.warning_once(
                    "`LLM.encode` is currently using `pooling_task = %s`.\n"
                    "Please use one of the more specific methods or set the "
                    "task directly when using `LLM.encode`:\n"
                    "  - For embeddings, use `LLM.embed(...)` "
                    "or `pooling_task=\"embed\"`.\n"
                    "  - For classification logits, use `LLM.classify(...)` "
                    "or `pooling_task=\"classify\"`.\n"
                    "  - For rewards, use `LLM.reward(...)` "
                    "or `pooling_task=\"reward\"`\n"
                    "  - For similarity scores, use `LLM.score(...)`.",
                    pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details.")

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput)

        if not io_processor_prompt:
            return model_outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs)

        return [
            PoolingRequestOutput[Any](request_id="",
                                      outputs=processed_outputs,
                                      prompt_token_ids=[],
                                      finished=True)
        ]
1033

1034
1035
1036
1037
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        This is `encode` with the "embed" pooling task. Prompts are batched
        automatically with respect to the memory constraint, so for the best
        performance pass all prompts together in a single list.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`; if not None it is
                applied to every `PoolingParams` used for the request.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            truncate_prompt_tokens: Forwarded to `encode`; if not None it is
                applied to every `PoolingParams` used for the request. This
                matches the option already offered by `embed` and `reward`.
                The default None keeps the previous behavior.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This is a thin wrapper around `encode` that always runs the "encode"
        pooling task.

        Args:
            prompts: The prompts to the LLM (positional-only). You may pass a
                sequence of prompts for batch inference. See
                [PromptType][vllm.inputs.PromptType] for more details about
                the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`; if not None it is
                applied to every `PoolingParams` used for the request.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        reward_outputs = self.encode(
            prompts,
            pooling_task="encode",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return reward_outputs

1168
1169
1170
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by the cosine similarity of their embeddings.

        Both sides are embedded in one `encode` call with the "embed" task.
        The first `len(text_1)` outputs belong to `text_1` and the remainder
        to `text_2`. When `text_1` yields a single embedding, that embedding
        is reused against every `text_2` entry (the 1 -> N case).

        Returns:
            One `ScoringRequestOutput` per scored pair.
        """
        split_at = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        left = embeddings[:split_at]
        right = embeddings[split_at:]

        if len(left) == 1:
            # Broadcast the single query embedding over all documents.
            left = left * len(right)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=left,
                                    embed_2=right)

        validated = self.engine_class.validate_outputs(scores,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder model.

        Each pair is turned into a single engine prompt by `get_score_prompt`
        and run with the "score" pooling task. A single-element `data_1` is
        replicated to pair with every `data_2` entry (the 1 -> N case).
        Pairs are formed with `zip`, so if the lengths still differ the
        extra items are ignored; length validation is presumably done by
        the caller (TODO confirm).

        Args:
            tokenizer: Tokenizer used to build the pair prompts.
            data_1: Query side of each pair.
            data_2: Document side of each pair.
            truncate_prompt_tokens: Validated against the model's
                `max_model_len` and turned into tokenization kwargs.
            use_tqdm: Progress-bar option forwarded to the engine.
            pooling_params: Base pooling params; defaults to
                `PoolingParams(task="score")`.
            lora_request: LoRA request to use, if any.

        Returns:
            One `ScoringRequestOutput` per scored pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        # Fetched once and reused for verification, truncation limits and
        # prompt construction.
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        # Handed to _validate_truncation_size, which presumably fills in the
        # truncation settings used by get_score_prompt (TODO confirm).
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts: list[PromptType] = []
        pooling_params_list: list[PoolingParams] = []

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # token_type_ids are removed from the prompt. If present, they
            # travel in compressed form inside a per-request copy of the
            # pooling params; otherwise the shared params object is reused.
            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1273
1274
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional prompt-token limit; forwarded
                unchanged to the cross-encoder or embedding scoring path.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                the `embed` nor the `classify` task, is a cross-encoder whose
                `num_labels` is not 1, or if a non-multi-modal model receives
                multi-modal input or a prompt that cannot be turned into text.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:
            # Text-only model: reject multi-modal params up front, then
            # normalize both sides into lists of plain strings.

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            def ensure_str(prompt: SingletonPrompt) -> str:
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        # Decode token prompts back to text (see the
                        # tokenizer note above).
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # An explicit check rather than `assert`: asserts are removed
                # under `python -O`, which would let non-text prompts through.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Scoring prompts must be strings, TextPrompt or "
                        f"TokensPrompt; got {type(prompt).__name__}")
                return prompt

            def to_str_list(data) -> list[str]:
                if isinstance(data, (str, dict)):
                    # Convert a single prompt to a list.
                    data = [data]
                return [ensure_str(t) for t in data]

            check_data_type(data_1)
            check_data_type(data_2)
            data_1 = to_str_list(data_1)
            data_2 = to_str_list(data_2)

        def unwrap(data):
            # A `ScoreMultiModalParam` carries its items under "content";
            # a bare string becomes a one-element list; anything else
            # (e.g. an already-normalized list) is returned unchanged.
            if isinstance(data, dict) and "content" in data:
                return data.get("content")
            if isinstance(data, str):
                return [data]
            return data

        data_1 = unwrap(data_1)
        data_2 = unwrap(data_2)

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Passed through unchanged to
                `llm_engine.reset_prefix_cache`; presumably selects which
                device's cache to reset (None = engine default) — confirm.

        Returns:
            The boolean reported by the engine.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
    def sleep(self, level: int = 1):
        """
        Suspend the engine; it must not process requests while asleep.

        The prefix cache is reset first, then the engine is put to sleep at
        the requested level. Making sure no requests are in flight between
        this call and the matching `wake_up` is the caller's responsibility.

        Args:
            level: How much state to release.
                Level 1 offloads the model weights and discards the kv cache,
                whose content is forgotten. The weights are backed up in CPU
                memory, so enough CPU memory must be available to hold them.
                It suits sleeping and then waking up to run the same model.
                Level 2 discards both the model weights and the kv cache and
                forgets both. It suits waking up to run a different or
                updated model, where the previous weights are not needed,
                and it reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Bring the engine back from sleep mode. See the
        [sleep][vllm.LLM.sleep] method for what each sleep level releases.

        Args:
            tags: Memory allocations to reallocate, each drawn from
                `("weights", "kv_cache")`. None reallocates all memory.
                Before the engine is used again, wake_up must have been
                called with every tag (or with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        # Local import: the V1 engine class is only needed for the check.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine), (
            "LLM.get_metrics() is only supported by the V1 LLM engine.")
        return self.llm_engine.get_metrics()
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate a batch of requests and submit each prompt to the engine.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: Sampling or pooling parameters, either one object shared
                by every prompt or a sequence with one entry per prompt.
                `SamplingParams` are switched to `FINAL_ONLY` output in place.
            use_tqdm: If truthy, wrap the prompts in an "Adding requests"
                progress bar; a callable is used as the bar factory instead
                of `tqdm`.
            lora_request: One LoRA request for all prompts, a sequence with
                one entry per prompt, or None.
            priority: Optional per-prompt priorities; 0 is used when omitted
                or empty.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` is a
                sequence whose length differs from the number of prompts, or
                if a prompt skips multi-modal data without providing a UUID.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Checked up front: a short list would otherwise raise IndexError in
        # the loop below only after earlier prompts were already submitted.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        for i, prompt in enumerate(it):
            if isinstance(prompt, dict):
                self._validate_mm_data_and_uuids(
                    prompt.get("multi_modal_data"),
                    prompt.get("multi_modal_uuids"))

            param = params[i] if isinstance(params, Sequence) else params

            tokenization_kwargs: dict[str, Any] = {}
            _validate_truncation_size(model_config.max_model_len,
                                      param.truncate_prompt_tokens,
                                      tokenization_kwargs)

            self._add_request(
                prompt,
                param,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
    def _validate_mm_data_and_uuids(
            self,
            multi_modal_data: Optional[Any],  # MultiModalDataDict
            multi_modal_uuids: Optional[Any],  # MultiModalUUIDDict
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1537
        then its corresponding UUID must be set.
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
                        if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[  # noqa: E501
                                modality] is None:
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
                                f"but UUID is not provided")
                        else:
                            if len(
                                    multi_modal_uuids[modality]
                            ) <= i or multi_modal_uuids[modality][i] is None:
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
                                    f"but UUID is not provided")
            else:
                if data is None and (multi_modal_uuids is None
                                     or modality not in multi_modal_uuids
                                     or multi_modal_uuids[modality] is None):
                    raise ValueError(f"Multi-modal data for {modality} is None"
                                     f" but UUID is not provided")
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under the next request ID.

        The ID is the next value of `self.request_counter`, converted to a
        string. The prompt, params and optional settings are forwarded
        unchanged to `llm_engine.add_request`.
        """
        new_id = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(new_id),
            prompt,
            params,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            priority=priority,
        )
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a "Processed prompts" progress bar
                sized by the engine's unfinished-request count; a callable is
                used as the bar factory instead of `tqdm`. For
                `RequestOutput`s the bar also shows estimated input/output
                token throughput.

        Returns:
            All finished outputs, sorted by their integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Read the elapsed time once. It may be 0.0 if an
                        # output finishes within the timer's resolution, so
                        # skip the speed update instead of dividing by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if the engine raised mid-run.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))