llm.py 71.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
15
16
                              BeamSearchSequence,
                              create_sort_beams_key_function)
17
18
from vllm.config import (CompilationConfig, ModelDType,
                         StructuredOutputsConfig, TokenizerMode, is_init_field)
19
20
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
nunjunj's avatar
nunjunj committed
21
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
22
                                         ChatTemplateContentFormatOption,
23
24
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
25
26
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
27
28
# yapf conflicts with isort for this block
# yapf: disable
29
30
31
32
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
33
                                          compress_token_type_ids,
34
                                          get_score_prompt)
35
# yapf: enable
36
37
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
38
39
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                         TokensPrompt)
40
from vllm.logger import init_logger
41
from vllm.lora.request import LoRARequest
42
from vllm.model_executor.layers.quantization import QuantizationMethods
43
44
45
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
46
from vllm.plugins.io_processors import get_io_processor
47
from vllm.pooling_params import PoolingParams
48
49
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
50
from vllm.tasks import PoolingTask
51
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
52
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
53
from vllm.usage.usage_lib import UsageContext
54
from vllm.utils import Counter, Device, as_iter, is_list_of
55
from vllm.v1.engine.llm_engine import LLMEngine
56
from vllm.v1.sample.logits_processor import LogitsProcessor
57

58
59
60
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

61
62
logger = init_logger(__name__)

63
64
_R = TypeVar("_R", default=Any)

65
66

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
67
68
69
70
71
72
73
74
75
76
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
77
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
78
79
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
80
81
82
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
83
84
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
85
86
87
88
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
89
90
        allowed_media_domains: If set, only media URLs that belong to this 
            domain can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
91
92
93
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
94
95
96
97
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
98
        quantization: The method used to quantize the model weights. Currently,
99
            we support "awq", "gptq", and "fp8" (experimental).
100
101
102
103
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
104
105
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
106
107
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
108
109
110
111
112
113
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
114
115
116
117
118
119
120
121
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vllm can automatically infer the kv cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the kv cache memory size. kv_cache_memory_bytes
            allows more fine-grain control of how much memory gets used when
            compared with using gpu_memory_utilization. Note that
            kv_cache_memory_bytes (when not None) ignores
            gpu_memory_utilization.
122
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
123
124
125
126
127
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Note that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
128
129
130
131
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
132
133
134
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
135
136
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
137
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
140
141
142
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
143
144
145
146
147
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
148
149
150
151
152
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
153
154
155
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
156
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
157

158
159
    Note:
        This class is intended to be used for offline inference. For online
160
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
161
    """
162
163
164
165

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        pooler_config: Optional[PoolerConfig] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        structured_outputs_config: Optional[Union[dict[
            str, Any], StructuredOutputsConfig]] = None,
        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """Normalize the arguments, build `EngineArgs` and start the engine.

        The keyword arguments are forwarded to `EngineArgs` after a few
        conveniences are applied:

        - `disable_log_stats` is set to `True` unless the caller supplied it.
        - A `worker_cls` given as a class object is serialized with
          cloudpickle.
        - A `kv_transfer_config` given as a dict is converted to a
          `KVTransferConfig`.
        - `hf_overrides=None` becomes an empty dict.
        - `compilation_config` may be an int (used as `level`), a dict (only
          keys accepted by `is_init_field` are kept) or a ready-made
          `CompilationConfig`; `None` yields a default instance.
        - `structured_outputs_config` may be a dict (filtered the same way)
          or a `StructuredOutputsConfig`; `None` yields a default instance.

        Raises:
            ValueError: If a dict `kv_transfer_config` fails pydantic
                validation.
        """
        # Only fill in the default when the caller did not choose a value.
        kwargs.setdefault("disable_log_stats", True)

        # A qualified string name is passed through untouched; an actual
        # class is serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        kv_transfer_raw = kwargs.get("kv_transfer_config")
        if isinstance(kv_transfer_raw, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **kv_transfer_raw)
            except ValidationError as err:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    kv_transfer_raw, err)
                # Surface the validation failure as a ValueError so the
                # user gets context about which argument was invalid.
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {err}") from err

        if hf_overrides is None:
            hf_overrides = {}

        # Normalize `compilation_config` into a CompilationConfig instance.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            accepted_compilation_fields = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            compilation_config_instance = CompilationConfig(
                **accepted_compilation_fields)
        else:
            compilation_config_instance = compilation_config

        # Normalize `structured_outputs_config` the same way.
        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            accepted_structured_fields = {
                key: value
                for key, value in structured_outputs_config.items()
                if is_init_field(StructuredOutputsConfig, key)
            }
            structured_outputs_instance = StructuredOutputsConfig(
                **accepted_structured_fields)
        else:
            structured_outputs_instance = structured_outputs_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the engine and remember its concrete class; the class is
        # used later to validate outputs.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled in lazily by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None

        supported_tasks = self.llm_engine.get_supported_tasks()  # type: ignore
        logger.info("Supported_tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        # Load the Input/Output processor plugin if any
        self.io_processor = get_io_processor(
            self.llm_engine.vllm_config,
            self.llm_engine.model_config.io_processor_plugin)

315
316
    def get_tokenizer(self) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer()
317
318

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
319
320
321
322
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
323
            self.llm_engine.tokenizer = tokenizer
324
        else:
325
            self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
326

327
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        The overrides are fetched from the model config on first use via
        `get_diff_sampling_param()` and stored on
        `self.default_sampling_params`. When they are empty, a plain
        `SamplingParams()` is returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            # First call: ask the model config and remember the answer.
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides

        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

335
336
337
338
339
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
340
        *,
341
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
342
343
344
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
345
346
        """Generates the completions for the input prompts.

347
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
348
349
350
351
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
352
            prompts: The prompts to the LLM. You may pass a sequence of prompts
353
                for batch inference. See [PromptType][vllm.inputs.PromptType]
354
                for more details about the format of each prompt.
Woosuk Kwon's avatar
Woosuk Kwon committed
355
            sampling_params: The sampling parameters for text generation. If
nunjunj's avatar
nunjunj committed
356
357
358
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
359
                prompts and it is paired one by one with the prompt.
360
361
362
363
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
364
            lora_request: LoRA request to use for generation, if any.
365
366
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
Woosuk Kwon's avatar
Woosuk Kwon committed
367
368

        Returns:
369
            A list of `RequestOutput` objects containing the
370
            generated completions in the same order as the input prompts.
371

372
373
374
375
        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `inputs` parameter.
376
        """
377
378
379
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
380
381
382
383
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")
384

385
386
        if sampling_params is None:
            # Use default sampling params.
387
            sampling_params = self.get_default_sampling_params()
388

389
390
        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
391
            prompts, lora_request)
392

393
        self._validate_and_add_requests(
394
            prompts=prompts,
395
            params=sampling_params,
396
            use_tqdm=use_tqdm,
397
            lora_request=lora_request,
398
399
            priority=priority,
        )
400

401
        outputs = self._run_engine(use_tqdm=use_tqdm)
Joe Runde's avatar
Joe Runde committed
402
        return self.engine_class.validate_outputs(outputs, RequestOutput)
403

404
    def _get_modality_specific_lora_reqs(
405
            self, prompts: Union[PromptType, Sequence[PromptType]],
406
407
408
409
410
411
412
413
414
415
416
417
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or (lora_config and lora_config.default_mm_loras is None)):
            return lora_request

418
419
        if not isinstance(prompts, Sequence):
            prompts = [prompts]
420

421
        optional_loras = ([lora_request] * len(prompts)
422
423
424
425
426
                          if not isinstance(lora_request, Sequence) else
                          lora_request)

        return [
            self._resolve_single_prompt_mm_lora(
427
                prompt,
428
429
                opt_lora_req,
                lora_config.default_mm_loras,
430
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
431
432
        ]

433
    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
434
435
436
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
437
438
        if (not default_mm_loras or not isinstance(prompt, dict)
                or "multi_modal_data" not in prompt):
439
440
            return lora_request

441
        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)
442

443
444
        intersection = set(prompt["multi_modal_data"].keys()) \
            .intersection(default_mm_loras.keys())
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # If we have a collision, warn if there is a collision,
        # but always send the explicitly provided request.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

479
    def collective_rpc(self,
480
                       method: Union[str, Callable[..., _R]],
481
                       timeout: Optional[float] = None,
482
483
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
484
485
486
487
488
489
490
491
492
493
494
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
495
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
496
497
498
499
500
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
501

502
503
504
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
505
        """
506
507

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
508
509

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
510
        """
511
512
        Run a function directly on the model inside each worker,
        returning the result for each of them.
513
514
515
516
517
518

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
519
        """
520
        return self.llm_engine.apply_model(func)
521

522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt."""
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

538
539
    def beam_search(
        self,
540
        prompts: list[Union[TokensPrompt, TextPrompt]],
541
        params: BeamSearchParams,
542
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
543
        use_tqdm: bool = False,
544
        concurrency_limit: Optional[int] = None,
545
    ) -> list[BeamSearchOutput]:
546
547
548
549
550
551
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
552
            params: The beam search parameters.
553
            lora_request: LoRA request to use for generation, if any.
554
            use_tqdm: Whether to use tqdm to display the progress bar.
555
556
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
557
        """
558
559
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
560
561
562
563
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
564
565
        length_penalty = params.length_penalty

566
567
568
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

569
570
571
572
573
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
574

575
576
577
578
579
580
581
582
583
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

584
585
586
587
588
589
590
591
592
593
594
595
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
596

597
598
599
600
601
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
602
                                            temperature=temperature)
603
        instances: list[BeamSearchInstance] = []
604

605
        for lora_req, prompt in zip(lora_requests, prompts):
606
607
608
609
610
611
612
613
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

614
615
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
616
617
618
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
619

620
            instances.append(
621
622
623
624
625
626
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
627

628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
701
702
703
704
705

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
706
                                      key=sort_beams_key,
707
708
709
710
711
712
713
714
715
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

716
    def preprocess_chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Generate prompt for a chat conversation. The pre-processed
        prompt can then be used as input for the other LLM methods.

        Refer to `chat` for a complete description of the arguments.

        Returns:
            A list of `TokensPrompt` objects, one per conversation,
            containing the tokenized prompt after chat template
            interpolation. `multi_modal_data`, `multi_modal_uuids` and
            `mm_processor_kwargs` entries are attached when they are
            not None.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Normalize the input so that we always iterate over a list of
        # conversations, even when a single conversation was passed.
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Caller-supplied `chat_template_kwargs` are applied last, so they
        # override the explicit arguments when the same key appears in both.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[TokensPrompt] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path receives the raw messages and its result
                # is used directly as the prompt token ids.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_uuids is not None:
                prompt["multi_modal_uuids"] = mm_uuids

            if mm_processor_kwargs is not None:
                # The same dict object is attached to every prompt.
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return prompts

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into tokenized prompts by
        `preprocess_chat`, which are then passed to the
        [generate][vllm.LLM.generate] method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of dictionaries forwarded unchanged to
                `preprocess_chat` for use by the chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """

        # All chat-specific arguments are consumed by the preprocessing
        # step; only generation-related arguments reach `generate`.
        prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

900
901
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: Optional[PoolingTask] = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
                A dict containing a `"data"` key is treated as an IO
                processor request: it is parsed and pre-processed by
                `self.io_processor` before being submitted.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If not None, this value is written to
                `truncate_prompt_tokens` on every pooling params object
                (kept for backwards compatibility). This modifies the
                objects passed in by the caller.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. If explicitly
                set to None, `"embed"` is chosen when it is a supported
                task and `"encode"` otherwise; a one-time warning is
                logged unless `"encode"` is the only supported task.
            tokenization_kwargs: Tokenization overrides.
                NOTE(review): this argument is not read anywhere in this
                method; confirm whether it should be forwarded.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For IO processor requests, a single-element list wrapping the
            post-processed output is returned instead.

        Raises:
            ValueError: If the model's runner type is not `"pooling"`, if
                `pooling_task` is not a supported task, or if an IO
                processor request is given but no IO processor is installed.
        """

        if self.supported_tasks == ["encode"] and pooling_task is None:
            pooling_task = "encode"

        if pooling_task is None:
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
            else:
                pooling_task = "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                "  - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details.")

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput)

        if not io_processor_prompt:
            return model_outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs)

        # The post-processed result is wrapped in a single output with an
        # empty request id and no prompt token ids.
        return [
            PoolingRequestOutput[Any](request_id="",
                                      outputs=processed_outputs,
                                      prompt_token_ids=[],
                                      finished=True)
        ]
1030

1031
1032
1033
1034
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded unchanged to `encode`.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If `"embed"` is not one of the supported tasks.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        # Delegate to the generic pooling entry point with the task pinned.
        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
        )

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ClassificationRequestOutput` objects, one per
            input prompt, in the same order as the input prompts.

        Raises:
            ValueError: If `"classify"` is not one of the supported tasks.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        # Delegate to the generic pooling entry point with the task pinned.
        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded unchanged to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """

        # Rewards are produced through the generic "encode" pooling task;
        # the result of `encode` is returned as-is.
        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="encode",
        )

1165
1166
1167
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs of inputs by embedding both sides and comparing them.

        Both input lists are embedded in a single `encode` call using the
        `"embed"` pooling task, and the resulting embeddings are passed to
        `_cosine_similarity`.

        Args:
            tokenizer: Tokenizer forwarded to `_cosine_similarity`.
            text_1: First side of the pairs. If it yields exactly one
                embedding, that embedding is repeated to match the number
                of `text_2` embeddings.
            text_2: Second side of the pairs.
            truncate_prompt_tokens: Forwarded unchanged to `encode`.
            use_tqdm: Forwarded unchanged to `encode`.
            pooling_params: Forwarded unchanged to `encode`.
            lora_request: Forwarded unchanged to `encode`.

        Returns:
            A list of `ScoringRequestOutput` objects built from the
            validated similarity results.
        """

        # Embed everything in one batch; the outputs come back in input
        # order, so the first `len(text_1)` entries belong to `text_1`.
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
        )

        num_first = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            0:num_first]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            num_first:]

        # 1 -> N case: reuse the single first-side embedding for every pair.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs of inputs by feeding each pair to the model jointly.

        One engine prompt is built per `(data_1, data_2)` pair with
        `get_score_prompt`, all prompts are submitted to the engine, and
        the validated outputs are converted to `ScoringRequestOutput`.

        Args:
            tokenizer: Tokenizer used to build the score prompts. Mistral
                tokenizers are rejected.
            data_1: First side of the pairs. A single-element list is
                repeated to match the length of `data_2`.
            data_2: Second side of the pairs.
            truncate_prompt_tokens: Passed to `_validate_truncation_size`
                together with the model's `max_model_len`.
            use_tqdm: Forwarded to request submission and the engine run.
            pooling_params: Pooling parameters; defaults to
                `PoolingParams(task="score")`. They are verified against
                the `"score"` task.
            lora_request: LoRA request to use, if any.

        Returns:
            A list of `ScoringRequestOutput` objects built from the
            validated engine outputs.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N case: pair the single first-side item with every second-side
        # item.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")

        pooling_params.verify("score", model_config)
        pooling_params_list = list[PoolingParams]()

        # NOTE(review): presumably `_validate_truncation_size` fills in
        # `tokenization_kwargs`; confirm against its definition.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts = list[PromptType]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # Token type ids are removed from the prompt; when present they
            # travel in compressed form on a per-request copy of the pooling
            # params so the shared params object is left untouched.
            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1270
1271
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Forwarded unchanged to the scoring
                implementation (`_cross_encoding_score` or
                `_embedding_score`). NOTE(review): presumably caps the number
                of prompt tokens kept after tokenization -- confirm against
                those helpers.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model does not run as a pooling model, supports
                neither the "embed" nor the "classify" task, is a cross
                encoder whose `num_labels` is not 1, or if multi-modal input
                is given to a model that is not multi-modal.
            TypeError: If a prompt given to a text-only model cannot be
                reduced to a string.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                # A dict carrying "content" is treated as a
                # ScoreMultiModalParam, which text-only models cannot score.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                # Reduce every accepted prompt form to plain text; token-id
                # prompts are decoded back to text with the tokenizer.
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check instead of `assert`: assertions are stripped
                # under `python -O`, which would let non-text prompts through.
                if not isinstance(prompt, str):
                    raise TypeError(
                        "Expected a text prompt for scoring, but got "
                        f"{type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Multi-modal path: unwrap the "content" payload of a
        # ScoreMultiModalParam, or wrap a bare string into a list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
1408

1409
1410
1411
1412
1413
1414
    def start_profile(self) -> None:
        """Delegate to the engine's `start_profile()`."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Delegate to the engine's `stop_profile()`."""
        engine = self.llm_engine
        engine.stop_profile()

1415
1416
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Forward a prefix-cache reset to the engine.

        Args:
            device: Passed through unchanged to the engine's
                `reset_prefix_cache`; `None` by default.

        Returns:
            Whatever the engine's `reset_prefix_cache` returns.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1417

1418
1419
1420
1421
1422
1423
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset (via `reset_prefix_cache`) before the
        engine is put to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1440
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1453

1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            The value produced by the engine's `get_metrics()`; per this
            method's annotation, a list of `Metric` objects capturing the
            current state of the aggregated metrics.

        Note:
            This method is only available with the V1 LLM engine.
        """
        engine = self.llm_engine
        return engine.get_metrics()

1466
1467
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate a batch of prompts and queue each one on the engine.

        Args:
            prompts: A single prompt (`str` or `dict`) or a sequence of
                prompts. A single prompt is wrapped into a one-element list.
            params: Either one params object shared by every prompt or a
                sequence with one entry per prompt.
            use_tqdm: If truthy, wraps the prompt iteration in a progress
                bar labelled "Adding requests". A callable is used as the
                progress-bar factory instead of `tqdm`.
            lora_request: Either one LoRA request shared by every prompt, a
                sequence with one entry per prompt, or `None`.
            priority: Optional per-prompt priorities. When `None` or empty,
                every request is added with priority 0.

        Raises:
            ValueError: If `params`, `lora_request` or a non-empty `priority`
                is a per-prompt sequence whose length differs from the number
                of prompts, or if multi-modal validation of a prompt fails.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Fail before anything is queued: otherwise a short priority list
        # raises IndexError in the loop below after some requests have
        # already been added to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        for i, prompt in enumerate(it):

            if isinstance(prompt, dict):
                self._validate_mm_data_and_uuids(
                    prompt.get("multi_modal_data"),
                    prompt.get("multi_modal_uuids"))

            param = params[i] if isinstance(params, Sequence) else params

            # Filled in by `_validate_truncation_size` and handed to the
            # engine with the request. NOTE(review): presumably carries the
            # tokenizer truncation settings -- confirm in that helper.
            tokenization_kwargs: dict[str, Any] = {}
            _validate_truncation_size(model_config.max_model_len,
                                      param.truncate_prompt_tokens,
                                      tokenization_kwargs)

            self._add_request(
                prompt,
                param,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1524

1525
1526
1527
1528
1529
1530
1531
    def _validate_mm_data_and_uuids(
            self,
            multi_modal_data: Optional[Any],  # MultiModalDataDict
            multi_modal_uuids: Optional[Any],  # MultiModalUUIDDict
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1532
        then its corresponding UUID must be set.
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
                        if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[  # noqa: E501
                                modality] is None:
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
                                f"but UUID is not provided")
                        else:
                            if len(
                                    multi_modal_uuids[modality]
                            ) <= i or multi_modal_uuids[modality][i] is None:
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
                                    f"but UUID is not provided")
            else:
                if data is None and (multi_modal_uuids is None
                                     or modality not in multi_modal_uuids
                                     or multi_modal_uuids[modality] is None):
                    raise ValueError(f"Multi-modal data for {modality} is None"
                                     f" but UUID is not provided")

1560
    def _add_request(
nunjunj's avatar
nunjunj committed
1561
        self,
1562
        prompt: PromptType,
nunjunj's avatar
nunjunj committed
1563
        params: Union[SamplingParams, PoolingParams],
1564
        tokenization_kwargs: Optional[dict[str, Any]] = None,
1565
        lora_request: Optional[LoRARequest] = None,
1566
        priority: int = 0,
1567
1568
    ) -> None:
        request_id = str(next(self.request_counter))
1569
1570
        self.llm_engine.add_request(
            request_id,
1571
            prompt,
1572
1573
            params,
            lora_request=lora_request,
1574
            tokenization_kwargs=tokenization_kwargs,
1575
            priority=priority,
nunjunj's avatar
nunjunj committed
1576
        )
1577

1578
    def _run_engine(
1579
1580
1581
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
1582
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
1583
1584
        # Initialize tqdm.
        if use_tqdm:
Zhuohan Li's avatar
Zhuohan Li committed
1585
            num_requests = self.llm_engine.get_num_unfinished_requests()
1586
1587
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
1588
1589
1590
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
1591
1592
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
1593
            )
1594

Zhuohan Li's avatar
Zhuohan Li committed
1595
        # Run the engine.
1596
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
1597
1598
        total_in_toks = 0
        total_out_toks = 0
Zhuohan Li's avatar
Zhuohan Li committed
1599
1600
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
1601
            for output in step_outputs:
1602
                if output.finished:
1603
1604
                    outputs.append(output)
                    if use_tqdm:
1605
1606
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
1607
                            n = len(output.outputs)
1608
                            assert output.prompt_token_ids is not None
1609
                            total_in_toks += len(output.prompt_token_ids) * n
1610
1611
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
1612
                                len(stp.token_ids) for stp in output.outputs)
nunjunj's avatar
nunjunj committed
1613
1614
                            out_spd = (total_out_toks /
                                       pbar.format_dict["elapsed"])
1615
1616
1617
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
1618
                            pbar.update(n)
1619
1620
                        else:
                            pbar.update(1)
1621
1622
                        if pbar.n == num_requests:
                            pbar.refresh()
1623

1624
1625
        if use_tqdm:
            pbar.close()
1626
1627
1628
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
1629
        return sorted(outputs, key=lambda x: int(x.request_id))