llm.py 72.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
15
16
                              BeamSearchSequence,
                              create_sort_beams_key_function)
17
18
from vllm.config import (CompilationConfig, ModelDType,
                         StructuredOutputsConfig, TokenizerMode, is_init_field)
19
20
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
nunjunj's avatar
nunjunj committed
21
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
22
                                         ChatTemplateContentFormatOption,
23
24
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
25
26
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
27
28
# yapf conflicts with isort for this block
# yapf: disable
29
30
31
32
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
33
                                          compress_token_type_ids,
34
                                          get_score_prompt)
35
# yapf: enable
36
37
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
38
39
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                         TokensPrompt)
40
from vllm.inputs.parse import get_prompt_components
41
from vllm.logger import init_logger
42
from vllm.lora.request import LoRARequest
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
45
46
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
47
from vllm.plugins.io_processors import get_io_processor
48
from vllm.pooling_params import PoolingParams
49
50
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
51
from vllm.tasks import PoolingTask
52
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
53
54
                                               get_cached_tokenizer,
                                               init_tokenizer_from_configs)
yhu422's avatar
yhu422 committed
55
from vllm.usage.usage_lib import UsageContext
56
from vllm.utils import Counter, Device, as_iter, is_list_of
57
from vllm.v1.engine import EngineCoreRequest
58
from vllm.v1.engine.llm_engine import LLMEngine
59
from vllm.v1.engine.processor import Processor
60
from vllm.v1.sample.logits_processor import LogitsProcessor
61

62
63
64
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

65
66
logger = init_logger(__name__)

# Result type of callables dispatched to workers via `collective_rpc` and
# `apply_model`; defaults to `Any` when unspecified (PEP 696 default).
_R = TypeVar("_R", default=Any)

69
70

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
71
72
73
74
75
76
77
78
79
80
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        runner: The runner type, forwarded to [`EngineArgs`][vllm.EngineArgs].
            Defaults to "auto".
        convert: The model conversion option, forwarded to
            [`EngineArgs`][vllm.EngineArgs]. Defaults to "auto".
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If True, skip initialization of the tokenizer and
            detokenizer. Inputs must then provide valid `prompt_token_ids`
            and `None` for `prompt`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images or
            videos from this directory on the server file system. This is a
            security risk and should only be enabled in trusted environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
            default this is None, and vLLM infers the KV cache size from
            `gpu_memory_utilization`. Setting it gives finer-grained control
            over how much memory is used than `gpu_memory_utilization` does.
            Note that when `kv_cache_memory_bytes` is not None,
            `gpu_memory_utilization` is ignored.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
        structured_outputs_config: Either a `StructuredOutputsConfig` or a
            dictionary of its fields; dictionary keys that are not
            `StructuredOutputsConfig` fields are ignored.
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig`. If it is an integer, it is used as the level
            of compilation optimization. If it is a dictionary, it can specify
            the full compilation configuration; keys that are not
            `CompilationConfig` fields are ignored.
        logits_processors: Custom logits processors, each given as a string or
            a `LogitsProcessor` class; forwarded to
            [`EngineArgs`][vllm.EngineArgs].
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """
166
167
168
169

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        pooler_config: Optional[PoolerConfig] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        structured_outputs_config: Optional[Union[dict[
            str, Any], StructuredOutputsConfig]] = None,
        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """Normalize the constructor options and start the engine.

        Parameters are described in the class docstring; anything not named
        explicitly is forwarded to ``EngineArgs`` through ``**kwargs``.

        Raises:
            ValueError: If ``kv_transfer_config`` is passed as a dict that
                ``KVTransferConfig`` fails to validate.
        """

        def _init_fields_only(config_cls: type,
                              raw: dict[str, Any]) -> dict[str, Any]:
            # Keep only the keys ``config_cls`` accepts as init fields; any
            # other key in the user-supplied dict is dropped.
            return {
                name: value
                for name, value in raw.items()
                if is_init_field(config_cls, name)
            }

        # Stats logging is off by default unless the caller sets it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type (not a qualified-name string) is
        # serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Promote a plain-dict kv_transfer_config to its config object.
        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config.kv_transfer import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        # compilation_config: None -> defaults, int -> optimization level,
        # dict -> known fields only, otherwise an existing CompilationConfig.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            compilation_config_instance = CompilationConfig(
                **_init_fields_only(CompilationConfig, compilation_config))
        else:
            compilation_config_instance = compilation_config

        # structured_outputs_config: None -> defaults, dict -> known fields
        # only, otherwise an existing StructuredOutputsConfig.
        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            structured_outputs_instance = StructuredOutputsConfig(
                **_init_fields_only(StructuredOutputsConfig,
                                    structured_outputs_config))
        else:
            structured_outputs_instance = structured_outputs_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        tasks = self.llm_engine.get_supported_tasks()  # type: ignore
        logger.info("Supported_tasks: %s", tasks)
        self.supported_tasks = tasks

        # Load the Input/Output processor plugin named in the model config,
        # if any.
        plugin_name = self.llm_engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(self.llm_engine.vllm_config,
                                             plugin_name)

319
320
321
322
    @property
    def model_config(self):
        """The model configuration of the underlying engine."""
        engine = self.llm_engine
        return engine.model_config

323
324
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer currently held by the engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
325
326

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Replace the engine's tokenizer.

        ``tokenizer`` is wrapped with ``get_cached_tokenizer`` unless it
        already looks cached. Cached tokenizer classes are created
        dynamically, so the only available check is the class-name prefix
        "Cached"; a user-defined class whose name happens to start with
        "Cached" is therefore installed without wrapping.
        """
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        if not looks_cached:
            tokenizer = get_cached_tokenizer(tokenizer)
        self.llm_engine.tokenizer = tokenizer
334

335
336
337
338
339
340
341
342
343
344
    def _get_processor(self) -> Processor:
        """Return the input ``Processor``, creating and caching it lazily.

        The processor is built without a tokenizer when the model config has
        ``skip_tokenizer_init`` set; otherwise a tokenizer is created from the
        model config with ``init_tokenizer_from_configs``.
        """
        try:
            return self._processor
        except AttributeError:
            pass  # First call: build the processor below.

        vllm_config = self.llm_engine.vllm_config
        model_config = self.model_config
        tokenizer = (None if model_config.skip_tokenizer_init else
                     init_tokenizer_from_configs(model_config))
        self._processor = Processor(vllm_config, tokenizer)
        return self._processor

345
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling parameters used when the caller passes none.

        The result of the model config's ``get_diff_sampling_param()`` is
        fetched once and cached on ``self.default_sampling_params``. If that
        result is empty, plain ``SamplingParams()`` is returned.
        """
        cached = self.default_sampling_params
        if cached is None:
            model_config = self.llm_engine.model_config
            cached = model_config.get_diff_sampling_param()
            self.default_sampling_params = cached

        if not cached:
            return SamplingParams()
        return SamplingParams.from_optional(**cached)

353
354
355
356
357
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the defaults from `get_default_sampling_params()` are
                used. When it is a single value, it is applied to every prompt.
                When it is a sequence, it must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. Either a
                single request applied to every prompt or a list with one
                entry per prompt. For multimodal models with `default_mm_loras`
                configured, a default modality-specific LoRA may be used for
                prompts that carry multimodal data.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not "generate".
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
421

422
    def _get_modality_specific_lora_reqs(
            self, prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default modality-specific LoRAs to multimodal prompts.

        When the engine's LoRA config defines ``default_mm_loras`` and the
        model is multimodal, each prompt is resolved individually with
        ``_resolve_single_prompt_mm_lora``. Otherwise ``lora_request`` is
        returned untouched.

        Args:
            prompts: A single prompt or a sequence of prompts. A bare string
                is treated as one prompt.
            lora_request: ``None``, a single request applied to every prompt,
                or a sequence with one entry per prompt.

        Returns:
            Either ``lora_request`` unchanged, or a list holding one (possibly
            ``None``) LoRA request per prompt.

        Raises:
            ValueError: If ``lora_request`` is a sequence whose length does
                not match the number of prompts (only checked when default
                multimodal LoRAs are in effect).
        """
        # Grab the lora config off the vllm config on the engine.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # ``str`` is itself a Sequence; a single text prompt must not be
        # iterated character by character.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # Fail loudly instead of letting ``zip`` silently drop entries.
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as "
                    "the prompts")
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

451
    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request for a single prompt.

        If the prompt is a dict whose ``multi_modal_data`` contains exactly
        one modality registered in ``default_mm_loras``, a ``LoRARequest``
        for that modality's default LoRA is built. An explicitly provided
        ``lora_request`` always takes precedence.

        Args:
            prompt: A single prompt.
            lora_request: The LoRA request explicitly given for this prompt,
                if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            The (possibly ``None``) LoRA request to use for this prompt.
        """
        if (not default_mm_loras or not isinstance(prompt, dict)
                or "multi_modal_data" not in prompt):
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)

        # Modalities present in the prompt that also have a default LoRA.
        intersection = set(prompt["multi_modal_data"].keys()).intersection(
            default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # When an explicit request is also given, it always wins; warn if it
        # refers to a different LoRA ID than the modality default.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

497
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Invoke ``method`` on every worker and gather what each one returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker. A callable must accept
                an additional `self` argument (the worker object) besides the
                values passed through `args` and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait; a
                [`TimeoutError`][] is raised when it is exceeded. `None`
                waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with one result per worker.

        Note:
            Prefer using this API for control messages only, and set up a
            separate data-plane channel for moving data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
526
527

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call ``func`` on the model held by each worker and return the list
        of per-worker results.

        !!! warning
            To keep data-transfer overhead low, avoid returning large arrays
            or tensors from `func`. If you must return them, move them to CPU
            first so they do not take up additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
539

540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: ``None``, a single request shared by every prompt,
                or a sequence with exactly one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A new list with one (possibly ``None``) LoRA request per prompt.

        Raises:
            ValueError: If ``lora_request`` is a sequence whose length differs
                from ``len(prompts)``.
            TypeError: If ``lora_request`` is of any other type.
        """
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        # A correctly sized per-prompt list (as the type hint advertises) is
        # returned as a copy. ``str`` is a Sequence too, but is never a valid
        # list of LoRA requests.
        if isinstance(lora_request, Sequence) and not isinstance(
                lora_request, str):
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

556
557
    def beam_search(
        self,
558
        prompts: list[Union[TokensPrompt, TextPrompt]],
559
        params: BeamSearchParams,
560
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
561
        use_tqdm: bool = False,
562
        concurrency_limit: Optional[int] = None,
563
    ) -> list[BeamSearchOutput]:
564
565
566
567
568
569
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
570
            params: The beam search parameters.
571
            lora_request: LoRA request to use for generation, if any.
572
            use_tqdm: Whether to use tqdm to display the progress bar.
573
574
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
575
        """
576
577
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
578
579
580
581
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
582
583
        length_penalty = params.length_penalty

584
585
586
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

587
588
589
590
591
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
592

593
594
595
596
597
598
599
600
601
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

602
603
604
605
606
607
608
609
610
611
612
613
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
614

615
616
617
618
619
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
620
                                            temperature=temperature)
621
        instances: list[BeamSearchInstance] = []
622

623
        for lora_req, prompt in zip(lora_requests, prompts):
624
625
626
627
628
629
630
631
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

632
633
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
634
635
636
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
637

638
            instances.append(
639
640
641
642
643
644
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
645

646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
719
720
721
722
723

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
724
                                      key=sort_beams_key,
725
726
727
728
729
730
731
732
733
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

734
    def preprocess_chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Render one or more chat conversations into tokenized prompts.

        The returned prompts can be passed straight to the other `LLM`
        methods. See `chat` for a full description of the arguments.

        Returns:
            One `TokensPrompt` per conversation, in input order. Each holds
            the token IDs produced by applying the chat template, plus the
            multi-modal data, multi-modal UUIDs and processor kwargs when
            they are present.
        """
        # A list of lists means a batch of conversations; anything else is a
        # single conversation, which is wrapped into a batch of one.
        conversations: list[list[ChatCompletionMessageParam]]
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit `chat_template_kwargs` entries override the defaults.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        template_kwargs.update(chat_template_kwargs or {})

        tokenized: list[TokensPrompt] = []
        for conversation_messages in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs (the chat message parsing has no
            # implementation for it), so they are attached further below.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_messages,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if not isinstance(tokenizer, MistralTokenizer):
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already contains the special tokens, so
                # the tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)
            else:
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_messages,
                    **template_kwargs,
                )

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                engine_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            tokenized.append(engine_prompt)

        return tokenized

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is first turned into a tokenized prompt with
        `preprocess_chat` (which applies the chat template), and the prompts
        are then handed to [generate][vllm.LLM.generate].

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation or a list of conversations.

                - A conversation is a list of messages.
                - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value applies to every prompt; a list must have the same
                length as the prompts and is paired with them one by one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template used to structure the chat. If not
                provided, the model's default chat template is used.
            chat_template_content_format: How message content is rendered.

                - "string" renders the content as a string.
                  Example: `"Who are you?"`
                - "openai" renders the content as a list of dictionaries,
                  similar to the OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to append
                a generation prompt after the conversation.
            continue_final_message: If True, continues the final message of
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions made available to the chat
                template.
            chat_template_kwargs: Additional kwargs passed to the chat
                template.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects holding the generated responses,
            in the same order as the input conversations.
        """
        tokenized_prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        return self.generate(tokenized_prompts,
                             sampling_params=sampling_params,
                             use_tqdm=use_tqdm,
                             lora_request=lora_request)

918
919
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: Optional[PoolingTask] = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is treated as an IOProcessor plugin
                request: it is parsed and pre-processed by the installed
                plugin, and the model outputs are post-processed by it.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If not None, it is written onto every
                pooling params object (kept for backwards compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. If None, encode-
                only models use `"encode"`; otherwise `"embed"` is chosen when
                supported (else `"encode"`) and a one-time hint is logged.
            tokenization_kwargs: Accepted for API compatibility but currently
                not applied by this method; a warning is logged if it is
                given. Use `truncate_prompt_tokens` / `pooling_params` to
                control truncation.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For an IOProcessor request, a single-element list wrapping the
            plugin's post-processed output.

        Raises:
            ValueError: If the model is not a pooling model, if
                `pooling_task` is not supported by the model, or if an
                IOProcessor request is given but no plugin is installed.
        """
        if tokenization_kwargs:
            # Nothing below forwards `tokenization_kwargs`; warn instead of
            # dropping the caller's settings silently.
            logger.warning_once(
                "`tokenization_kwargs` is not used by `LLM.encode` and will "
                "be ignored. Use `truncate_prompt_tokens` or "
                "`pooling_params` to control truncation.")

        # Normalize to a tuple so this check holds whether `supported_tasks`
        # is stored as a list or a tuple (a list never equals a tuple).
        if (pooling_task is None
                and tuple(self.supported_tasks) == ("encode", )):
            pooling_task = "encode"

        if pooling_task is None:
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
            else:
                pooling_task = "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                "  - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details.")

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput)

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs)

            return [
                PoolingRequestOutput[Any](request_id="",
                                          outputs=processed_outputs,
                                          prompt_token_ids=[],
                                          finished=True)
            ]
        else:
            return model_outputs
1048

1049
1050
1051
1052
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        This is `encode` with `pooling_task="embed"`. Prompts are batched
        automatically under the memory constraint; for the best performance,
        put all of your prompts into a single list and pass it here.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which applies it
                to every pooling params object when not None.
            pooling_params: The pooling parameters for pooling. If None, the
                default pooling parameters are used.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects holding the embedding
            vectors, in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"embed"` task.
        """
        if "embed" in self.supported_tasks:
            pooled = self.encode(
                prompts,
                pooling_task="embed",
                pooling_params=pooling_params,
                truncate_prompt_tokens=truncate_prompt_tokens,
                lora_request=lora_request,
                use_tqdm=use_tqdm,
            )
            return list(map(EmbeddingRequestOutput.from_base, pooled))

        raise ValueError(
            "Embedding API is not supported by this model. "
            "Try converting the model using `--convert embed`.")

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Forwarded to `encode`, which applies it
                to every pooling params object when not None (same behavior
                as in `embed` and `reward`).

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"classify"` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Thin wrapper around `encode` that runs the `"encode"` pooling task.

        Args:
            prompts: The prompts to the LLM (positional-only). You may pass a
                sequence of prompts for batch inference. See
                [PromptType][vllm.inputs.PromptType] for more details about
                the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which applies it
                to every pooling params object when not None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, the
                default pooling parameters are used.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects holding the pooled
            hidden states, in the same order as the input prompts.
        """
        return self.encode(
            prompts,
            pooling_task="encode",
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

1183
1184
1185
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `text_1` against `text_2` via cosine similarity of embeddings.

        Both sides are embedded in a single `encode` call (task `"embed"`).
        The first `len(text_1)` outputs belong to `text_1` and the rest to
        `text_2`. If the `text_1` side yields exactly one embedding, it is
        repeated to pair with every `text_2` embedding (the 1 -> N case).
        """
        split_at = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        left, right = embeddings[:split_at], embeddings[split_at:]

        # Broadcast a single left-hand embedding across all right-hand ones.
        if len(left) == 1:
            left = left * len(right)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=left,
                                          embed_2=right)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `data_1`/`data_2` pairs with a cross-encoder model.

        Pairs are formed position-wise. A single-element `data_1` is first
        repeated to match the length of `data_2` (the 1 -> N case). Each
        pair is turned into one engine prompt with `get_score_prompt`.

        Args:
            tokenizer: Tokenizer used to build the score prompts.
            data_1: Left-hand inputs (text or score content parts).
            data_2: Right-hand inputs (text or score content parts).
            truncate_prompt_tokens: Validated against the model's
                `max_model_len` and applied through the tokenization kwargs.
            use_tqdm: Progress-bar option forwarded to the engine helpers.
            pooling_params: Pooling parameters to use; defaults to
                `PoolingParams(task="score")`. Verified for the `"score"`
                task.
            lora_request: LoRA request to use, if any.

        Returns:
            A list of `ScoringRequestOutput`, one per pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        # Fetched once; it is not reassigned anywhere below.
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts: list[PromptType] = []
        pooling_params_list: list[PoolingParams] = []
        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # Token type IDs are removed from the prompt and passed, in
            # compressed form, via the `extra_kwargs` of a per-request copy
            # of the pooling params; prompts without them share the original.
            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1288
1289
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional prompt-truncation length,
                forwarded unchanged to the scoring helper (presumably
                truncates each prompt to this many tokens; it is validated
                against the model's `max_model_len`).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for scoring, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                the "embed" nor the "classify" task, is a cross-encoder whose
                `num_labels` is not 1, or (for text-only models) is given
                multi-modal or otherwise unsupported prompts.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:
            # Text-only model: reject multi-modal inputs and normalize every
            # prompt into a plain string.

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt) -> str:
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        # The tokenizer for models such as
                        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't
                        # support passing lists of tokens to the `text` and
                        # `text_pair` kwargs, so token IDs are decoded back
                        # into text here.
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # An explicit raise (not `assert`) so this input validation
                # is still enforced when Python runs with `-O`.
                if not isinstance(prompt, str):
                    raise ValueError("Unsupported prompt type for scoring: "
                                     f"{type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap `ScoreMultiModalParam` into its content list and lift a
        # bare string into a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
1426

1427
1428
1429
1430
1431
1432
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling.

        The engine's return value, if any, is discarded.
        """
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to end profiling.

        The engine's return value, if any, is discarded.
        """
        engine = self.llm_engine
        engine.stop_profile()

1433
1434
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to the
                engine's `reset_prefix_cache`.

        Returns:
            Whatever the engine's `reset_prefix_cache(device)` reports
            (annotated as `bool`).
        """
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache(device)
        return outcome
1435

1436
1437
1438
1439
1440
1441
    def sleep(self, level: int = 1):
        """Put the engine to sleep so that it stops processing requests.

        The prefix cache is reset first. The caller must make sure no
        requests are being processed from this call until `wake_up` is
        invoked.

        Args:
            level: How deeply to sleep.
                Level 1 offloads the model weights and discards the kv cache,
                whose contents are forgotten. The weights are backed up in CPU
                memory, so make sure there is enough CPU memory to hold them.
                Use it when the engine will be woken up to run the same model
                again.
                Level 2 discards both the model weights and the kv cache, so
                neither survives. Use it when the engine will be woken up to
                run a different or updated model and the previous weights are
                not needed; it reduces CPU memory pressure.
        """
        # Clear the prefix cache before handing control to the engine's sleep.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1458
    def wake_up(self, tags: Optional[list[str]] = None):
        """Bring the engine back from sleep mode.

        See the [sleep][vllm.LLM.sleep] method for details on sleep levels.

        Args:
            tags: Restrict reallocation to specific memory allocations;
                allowed values are `("weights", "kv_cache")`. With `None`
                (the default), all memory is reallocated. Before the engine
                is used again, `wake_up` must have been called with every
                tag (or with `None`).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1471

1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, exactly as returned by the
            engine's `get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1484
1485
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate a batch of requests and submit each one to the engine.

        Args:
            prompts: A single prompt (string or dict) or a sequence of
                prompts. A single prompt is wrapped into a one-element list.
            params: Either one params object shared by every prompt, or a
                sequence with exactly one entry per prompt. `SamplingParams`
                objects are mutated in place: their `output_kind` is set to
                `RequestOutputKind.FINAL_ONLY`.
            use_tqdm: If truthy, wrap request submission in a progress bar;
                a callable is used as the tqdm factory.
            lora_request: Either one LoRA request shared by every prompt, or a
                sequence with exactly one entry per prompt.
            priority: Optional per-prompt priorities. When non-empty it must
                have one entry per prompt; otherwise every request gets 0.

        Raises:
            ValueError: If `params`, `lora_request` or a non-empty `priority`
                is a sequence whose length differs from the number of prompts.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Check up front: a mismatched list would otherwise raise IndexError
        # part-way through the loop below, after earlier requests were
        # already submitted to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):

            if isinstance(prompt, dict):
                self._validate_mm_data_and_uuids(
                    prompt.get("multi_modal_data"),
                    prompt.get("multi_modal_uuids"))

            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1532

1533
1534
1535
1536
1537
1538
1539
    def _validate_mm_data_and_uuids(
            self,
            multi_modal_data: Optional[Any],  # MultiModalDataDict
            multi_modal_uuids: Optional[Any],  # MultiModalUUIDDict
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1540
        then its corresponding UUID must be set.
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
                        if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[  # noqa: E501
                                modality] is None:
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
                                f"but UUID is not provided")
                        else:
                            if len(
                                    multi_modal_uuids[modality]
                            ) <= i or multi_modal_uuids[modality][i] is None:
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
                                    f"but UUID is not provided")
            else:
                if data is None and (multi_modal_uuids is None
                                     or modality not in multi_modal_uuids
                                     or multi_modal_uuids[modality] is None):
                    raise ValueError(f"Multi-modal data for {modality} is None"
                                     f" but UUID is not provided")

1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest],
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Turn one prompt into an engine request using the Processor.

        `params.truncate_prompt_tokens` is first checked against the model's
        `max_model_len` by `_validate_truncation_size`, which is handed an
        initially empty kwargs dict (presumably filled with tokenization
        settings). That dict is passed to the processor.

        Returns:
            An `(engine_request, tokenization_kwargs)` pair. The kwargs dict
            is returned so the caller can forward the same settings.
        """
        kwargs_for_tokenizer: dict[str, Any] = {}
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            kwargs_for_tokenizer,
        )

        request = self._get_processor().process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=kwargs_for_tokenizer,
            priority=priority,
        )
        return request, kwargs_for_tokenizer

1594
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Process a single prompt and enqueue it on the engine.

        The request ID is the next value of `self.request_counter`, turned
        into a string. The prompt is run through `_process_inputs`, and the
        result is passed to `llm_engine.add_request` together with the
        prompt's text component.
        """
        prompt_text, _, _ = get_prompt_components(prompt)
        req_id = str(next(self.request_counter))

        # Both calls receive the same LoRA / priority settings.
        shared_kwargs = dict(lora_request=lora_request, priority=priority)

        core_request, tok_kwargs = self._process_inputs(
            req_id, prompt, params, **shared_kwargs)

        self.llm_engine.add_request(
            req_id,
            core_request,
            params,
            tokenization_kwargs=tok_kwargs,
            prompt_text=prompt_text,
            **shared_kwargs,
        )
1621

1622
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every pending request has finished.

        Args:
            use_tqdm: If truthy, show a progress bar whose total is the number
                of unfinished requests when this method starts; a callable is
                used as the tqdm factory.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # `elapsed` can be 0 (tqdm reports 0 for a disabled
                        # bar, e.g. `disable=True` passed via a `use_tqdm`
                        # partial), so skip the estimate rather than divide
                        # by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)

                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if a step raises, so no dangling progress
            # line is left behind.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))