llm.py 67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
import vllm.envs as envs
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
16
17
                              BeamSearchSequence,
                              create_sort_beams_key_function)
18
19
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
20
21
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
Joe Runde's avatar
Joe Runde committed
22
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
23
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
24
                                         ChatTemplateContentFormatOption,
25
26
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
27
28
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
29
30
# yapf conflicts with isort for this block
# yapf: disable
31
32
33
34
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
35
                                          compress_token_type_ids,
36
                                          get_score_prompt)
37
# yapf: enable
38
39
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
40
41
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                         TokensPrompt)
42
from vllm.logger import init_logger
43
from vllm.lora.request import LoRARequest
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
46
47
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
48
from vllm.plugins.io_processors import get_io_processor
49
from vllm.pooling_params import PoolingParams
50
51
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
52
from vllm.tasks import PoolingTask
53
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
54
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
55
from vllm.usage.usage_lib import UsageContext
56
from vllm.utils import Counter, Device, as_iter, is_list_of
57
from vllm.v1.sample.logits_processor import LogitsProcessor
58

59
60
61
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

62
63
# Module-scoped logger shared by every method of `LLM`.
logger = init_logger(__name__)

# Return type of callables dispatched to workers (see `collective_rpc` and
# `apply_model`); defaults to `Any` when it cannot be inferred.
_R = TypeVar("_R", default=Any)

66
67

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
68
69
70
71
72
73
74
75
76
77
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
78
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
79
80
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
81
82
83
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
84
85
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
86
87
88
89
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
90
91
92
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
93
94
95
96
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
97
        quantization: The method used to quantize the model weights. Currently,
98
            we support "awq", "gptq", and "fp8" (experimental).
99
100
101
102
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
103
104
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
105
106
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
107
108
109
110
111
112
113
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
119
120
121
122
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
123
124
125
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
126
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
127
            When a sequence has context length larger than this, we fall back
128
129
130
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
131
132
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
133
134
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
135
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
138
139
140
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
141
142
143
144
145
146
147
148
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
149
150
151
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the level of compilation optimization. If it is a dictionary, it
            can specify the full compilation configuration; keys that are not
            init fields of `CompilationConfig` are ignored.
152
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
153

154
155
    Note:
        This class is intended to be used for offline inference. For online
156
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
157
    """
158
159
160
161

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        The arguments are described in the class docstring. Any extra
        keyword arguments are normalized below and then forwarded to
        `EngineArgs`.
        """
        # Stats logging is off unless the caller explicitly asks for it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a type (rather than a qualified name) is
        # serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # A plain-dict kv_transfer_config is validated into its config object.
        raw_config_dict = kwargs.get("kv_transfer_config")
        if isinstance(raw_config_dict, dict):
            from vllm.config.kv_transfer import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        hf_overrides = {} if hf_overrides is None else hf_overrides

        # Normalize every accepted form into a CompilationConfig instance.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keys that are not init fields of CompilationConfig are dropped.
            init_kwargs = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            compilation_config_instance = CompilationConfig(**init_kwargs)
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        engine = self.llm_engine
        self.engine_class = type(engine)

        self.request_counter = Counter()
        # Filled lazily by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None

        # V1 engines report their tasks directly; V0 exposes them on the
        # model config.
        supported_tasks = (
            engine.get_supported_tasks()  # type: ignore
            if envs.VLLM_USE_V1 else engine.model_config.supported_tasks)
        logger.info("Supported_tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        # Load the Input/Output processor plugin, if one is configured.
        io_processor_plugin = engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(engine.vllm_config,
                                             io_processor_plugin)

294
295
296
297
298
299
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the engine's tokenizer, resolved for `lora_request`.

        The lookup is delegated to the engine's tokenizer group via
        `get_lora_tokenizer`.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
300
301

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        The tokenizer is wrapped with `get_cached_tokenizer` unless it
        already appears to be a cached tokenizer.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are created dynamically, so the class name
        # is the only available signal. A user-defined tokenizer whose class
        # name starts with "Cached" will be misjudged and stored unwrapped.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if looks_cached else get_cached_tokenizer(tokenizer))
311

312
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the model's default sampling parameters.

        The overrides come from the model config's `get_diff_sampling_param()`.
        They are fetched on first use and memoized in
        `self.default_sampling_params`. When there are no overrides, a plain
        `SamplingParams()` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = model_config.get_diff_sampling_param(
            )

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

320
321
322
323
324
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters returned by
                `get_default_sampling_params()`.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. For
                multimodal models with default modality-specific LoRAs
                configured, a prompt without an explicit request may be
                assigned the LoRA registered for its modality.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
388

389
    def _get_modality_specific_lora_reqs(
            self, prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default modality-specific LoRAs to multimodal prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: `None`, a single LoRA request applied to every
                prompt, or a sequence of requests paired one by one with the
                prompts.

        Returns:
            `lora_request` unchanged when there is no LoRA config, the model
            is not multimodal, or `default_mm_loras` is not configured.
            Otherwise, a list with one (possibly `None`) LoRA request per
            prompt, as resolved by `_resolve_single_prompt_mm_lora`.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # A bare string is itself a Sequence (of characters), so it has to be
        # wrapped explicitly; otherwise each character would be treated as a
        # separate prompt and one LoRA entry would be produced per character.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently truncate a mismatched list.
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

418
    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request to use for a single prompt.

        Args:
            prompt: The prompt to inspect. Only dict prompts that carry
                `multi_modal_data` are considered.
            lora_request: The explicitly provided LoRA request, if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            `lora_request` when no default modality LoRA applies, when more
            than one registered modality is present in the prompt, or when
            an explicit request was given. Otherwise, a new `LoRARequest` for
            the single matching modality.
        """
        if (not default_mm_loras or not isinstance(prompt, dict)
                or "multi_modal_data" not in prompt):
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)

        # The key may be present but set to None (or empty); there is no
        # modality to match in that case, and calling `.keys()` on None
        # would raise AttributeError.
        mm_data = prompt["multi_modal_data"]
        if not mm_data:
            return lora_request

        intersection = set(mm_data.keys()).intersection(
            default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the default modality LoRA's ID.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

464
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Run an RPC call on every worker and gather the results.

        Args:
            method: Either the name of a worker method to run, or a callable
                that is serialized and sent to every worker.

                A callable must accept an additional `self` argument (the
                worker object) besides whatever it receives through `args`
                and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. A
                [`TimeoutError`][] is raised when it expires; `None` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list holding the result produced by each worker.

        Note:
            Prefer this API for control messages only; set up a separate
            data-plane channel for moving data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
493
494

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` directly on the model held by each worker.

        Returns:
            The value `func` returned on each worker, delegated to the
            engine's model executor.
        """
        return self.llm_engine.model_executor.apply_model(func)
501

502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: `None`, a single `LoRARequest` applied to every
                prompt, or a sequence holding one entry per prompt.
            prompts: The beam search prompts.

        Returns:
            A list with exactly one (possibly `None`) LoRA request per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from `len(prompts)`.
            TypeError: If `lora_request` has any other type.
        """
        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            # A correctly sized sequence is valid input. Previously it fell
            # through to the TypeError below. Copy it into a new list so the
            # result matches the declared return type and is independent of
            # the caller's object.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

518
519
    def beam_search(
        self,
520
        prompts: list[Union[TokensPrompt, TextPrompt]],
521
        params: BeamSearchParams,
522
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
523
        use_tqdm: bool = False,
524
        concurrency_limit: Optional[int] = None,
525
    ) -> list[BeamSearchOutput]:
526
527
528
529
530
531
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
532
            params: The beam search parameters.
533
            lora_request: LoRA request to use for generation, if any.
534
            use_tqdm: Whether to use tqdm to display the progress bar.
535
536
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
537
        """
538
539
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
540
541
542
543
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
544
545
        length_penalty = params.length_penalty

546
547
548
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

549
550
551
552
553
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
554

555
556
557
558
559
560
561
562
563
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

564
565
566
567
568
569
570
571
572
573
574
575
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
576

577
578
579
580
581
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
582
                                            temperature=temperature)
583
        instances: list[BeamSearchInstance] = []
584

585
        for lora_req, prompt in zip(lora_requests, prompts):
586
587
588
589
590
591
592
593
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

594
595
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
596
597
598
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
599

600
            instances.append(
601
602
603
604
605
606
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
607

608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
681
682
683
684
685

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
686
                                      key=sort_beams_key,
687
688
689
690
691
692
693
694
695
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
696
697
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered to prompt token ids with a chat
        template. `MistralTokenizer` uses the Mistral template path; every
        other tokenizer uses the Hugging Face template path. The resulting
        prompts are handed to [generate][vllm.LLM.generate].

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation, or a list of conversations.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value is applied to every conversation. A list must have the
                same length as the conversations and is paired with them one
                by one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. It is
                also used to pick the tokenizer.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. A key given here overrides the matching explicit
                argument above (`chat_template`, `add_generation_prompt`,
                `continue_final_message`, `tools`).
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request, attached to every prompt built by this call.
                Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # A list whose items are lists is a batch of conversations; anything
        # else is a single conversation and gets wrapped in a batch of one.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(list[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit arguments first; user-supplied template kwargs are merged
        # on top and therefore win on any key collision.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        template_kwargs.update(chat_template_kwargs or {})

        is_mistral = isinstance(tokenizer, MistralTokenizer)

        def render_token_ids(msgs, conversation):
            # Turn one conversation into prompt token ids.
            if is_mistral:
                return apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **template_kwargs,
                )
            rendered = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **template_kwargs,
            )
            # The chat template already emits the special tokens, so the
            # tokenizer must not add them a second time.
            return tokenizer.encode(rendered, add_special_tokens=False)

        prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            prompt = TokensPrompt(
                prompt_token_ids=render_token_ids(msgs, conversation))

            # Only attach the optional multimodal fields that are present.
            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

844
845
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states of the input prompts.

        All prompts are submitted to the engine before it is run once. The
        engine batches them itself, taking the memory constraint into
        account. For the best performance, pass all of your prompts to a
        single call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
                A dict containing a `"data"` key is treated as an IOProcessor
                plugin request. The installed plugin parses and pre-processes
                it into model prompts, and its post-processed result is
                returned.
            pooling_params: The pooling parameters: one object shared by all
                prompts, or one per prompt. If None, `PoolingParams()` is used.
                Every object is verified against `pooling_task`.
            truncate_prompt_tokens: Kept for backwards compatibility. When not
                None, it is assigned to `truncate_prompt_tokens` on every
                pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. It must be one of
                `self.supported_tasks`. If explicitly set to None, "embed" is
                used when supported (otherwise "encode"), and a one-time
                warning points at the task-specific methods.
            tokenization_kwargs: Accepted for API compatibility.
                NOTE(review): this argument is not referenced anywhere in this
                method's body. Confirm whether it should be forwarded to
                `_validate_and_add_requests`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the pooled
            hidden states in the same order as the input prompts. For an
            IOProcessor request, a single-element list that wraps the
            plugin's post-processed output.

        Raises:
            ValueError: If the model is not a pooling model, if
                `pooling_task` is not supported, or if an IOProcessor request
                is given but no plugin is installed.
        """
        if pooling_task is None:
            pooling_task = ("embed"
                            if "embed" in self.supported_tasks else "encode")
            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                "  - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            if truncate_prompt_tokens is not None:
                # Legacy keyword argument: fold it into the params object.
                param.truncate_prompt_tokens = truncate_prompt_tokens

        # A dict request carrying "data" is routed through the IOProcessor
        # plugin, which converts it into the actual model prompts.
        uses_io_processor = isinstance(prompts, dict) and "data" in prompts
        if uses_io_processor:
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details.")
            validated_prompt = self.io_processor.parse_request(prompts)
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
        model_outputs = self.engine_class.validate_outputs(
            self._run_engine(use_tqdm=use_tqdm), PoolingRequestOutput)

        if not uses_io_processor:
            return model_outputs

        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs)
        # The plugin's post-processed result is returned wrapped in a single
        # placeholder request output.
        return [
            PoolingRequestOutput[Any](request_id="",
                                      outputs=processed_outputs,
                                      prompt_token_ids=[],
                                      finished=True)
        ]
970

971
972
973
974
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        This is a thin wrapper around `encode` with the "embed" pooling task.
        The engine batches the prompts itself, taking the memory constraint
        into account. For the best performance, pass all of your prompts in a
        single list.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                applies it to every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        pooled_outputs = self.encode(
            prompts,
            pooling_params=pooling_params,
            pooling_task="embed",
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                applies it to every pooling params object. This matches the
                same keyword on `embed` and `reward`. Defaults to None, which
                leaves behavior unchanged.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute reward outputs for each prompt.

        Delegates to `encode` with the generic "encode" pooling task. `encode`
        checks that the model is a pooling model and supports that task.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                applies it to every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        reward_outputs = self.encode(
            prompts,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="encode",
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return reward_outputs

1105
1106
1107
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by the cosine similarity of their embeddings.

        `text_1` and `text_2` are embedded together in one `encode` call with
        the "embed" task. The outputs are then split back at the boundary
        between the two inputs. If the first side produced exactly one
        embedding, it is repeated to match the second side before
        `_cosine_similarity` compares them.

        Returns:
            `ScoringRequestOutput` objects built from the validated similarity
            results.
        """
        # Index where text_2's embeddings start in the combined output.
        boundary = len(text_1)

        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
        )
        left, right = embeddings[:boundary], embeddings[boundary:]

        # 1 -> N scoring: repeat the single left embedding for every right one.
        if len(left) == 1:
            left = left * len(right)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=left,
                                          embed_2=right)
        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder model.

        Each pair is rendered into one engine prompt by `get_score_prompt`.
        All prompts are submitted with the "score" pooling task and run in a
        single engine pass.

        Args:
            tokenizer: Tokenizer used to build each pair prompt. Mistral
                tokenizers are rejected.
            data_1: Left-hand inputs. A single-element list is repeated to
                match the length of `data_2`.
            data_2: Right-hand inputs.
            truncate_prompt_tokens: Passed to `_validate_truncation_size` to
                fill the tokenization kwargs used for every pair prompt.
            use_tqdm: Progress-bar control, forwarded to request submission
                and to the engine run.
            pooling_params: Base pooling parameters. Defaults to
                `PoolingParams(task="score")`. Verified for the "score" task.
            lora_request: LoRA request to use, if any.

        Returns:
            `ScoringRequestOutput` objects built from the engine's validated
            pooling outputs.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        # Read the config once; the original re-fetched it three times.
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N: pair the single left-hand input with every right-hand one.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts: list[PromptType] = []
        pooling_params_list: list[PoolingParams] = []
        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                # Token type ids leave the prompt and travel in a private clone
                # of the params, so they stay attached to this request only.
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                # No per-request data: share the base params object.
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1210
1211
    def score(
        self,
1212
1213
1214
1215
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
1216
        /,
1217
        *,
1218
        truncate_prompt_tokens: Optional[int] = None,
1219
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
1220
        pooling_params: Optional[PoolingParams] = None,
1221
1222
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
1223
1224
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.
1225

1226
        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
1227
1228
        In the `1 - N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
1229
        The input pairs are used to build a list of prompts for the
1230
1231
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
1232
1233
1234
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
1235
        appropriate multi-modal models. For multi-modal inputs, ensure the
1236
        prompt structure matches the model's expected input format.
1237
1238

        Args:
1239
1240
1241
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
1242
                the `data_2` list.
1243
            data_2: The data to pair with the query to form the input to
1244
                the LLM. Can be text or multi-modal data. See [PromptType]
1245
                [vllm.inputs.PromptType] for more details about the format of
1246
                each prompt.
1247
1248
1249
1250
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
1251
            lora_request: LoRA request to use for generation, if any.
1252
1253
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
1254
        Returns:
1255
            A list of `ScoringRequestOutput` objects containing the
1256
1257
            generated scores in the same order as the input prompts.
        """
1258
1259
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
1260
        if runner_type != "pooling":
1261
1262
1263
1264
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")
1265

1266
1267
        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
1268
            raise ValueError("Score API is not supported by this model. "
1269
1270
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")
1271

1272
        if (model_config.is_cross_encoder
1273
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
1274
            raise ValueError("Score API is only enabled for num_labels == 1.")
1275
1276
1277
1278

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
1279
        tokenizer = self.get_tokenizer()
1280

1281
        if not model_config.is_multimodal_model:
1282
1283
1284
1285
1286

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                if isinstance(data, dict) and "content" in data:
1287
1288
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                assert type(prompt) is str
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]
1329

1330
        if model_config.is_cross_encoder:
1331
1332
1333
1334
1335
1336
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
1337
                pooling_params,
1338
                lora_request)
1339
        else:
1340
1341
            return self._embedding_score(
                tokenizer,
1342
1343
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
1344
1345
                truncate_prompt_tokens,
                use_tqdm,
1346
                pooling_params,
1347
                lora_request)
1348

1349
1350
1351
1352
1353
1354
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1355
1356
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to the
                engine's ``reset_prefix_cache``.

        Returns:
            The boolean result reported by the engine.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1357

1358
1359
1360
1361
1362
1363
    def sleep(self, level: int = 1):
        """
        Put the engine into sleep mode.

        While asleep the engine must not handle any requests: from the moment
        this is called until `wake_up` is invoked, the caller is responsible
        for ensuring that nothing is being processed.

        Args:
            level: How deeply to sleep.

                - Level 1 offloads the model weights and discards the kv
                  cache, whose contents are forgotten. The weights are backed
                  up in CPU memory, so make sure there is enough CPU memory to
                  hold them. This level suits sleeping and waking up to run
                  the same model again.
                - Level 2 discards both the model weights and the kv cache;
                  the contents of both are forgotten. This level suits waking
                  up to run a different model, or to update the model, when
                  the previous weights are no longer needed. It reduces CPU
                  memory pressure.
        """
        # The prefix cache is cleared before the engine is put to sleep.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1380
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Bring the engine back from sleep mode.

        See the [sleep][vllm.LLM.sleep] method for more details.

        Args:
            tags: Optional list selecting which memory allocations the engine
                should reallocate; each entry must be one of
                `("weights", "kv_cache")`. With None, all memory is
                reallocated. Before the engine is used again, wake_up must
                have been called with every tag (or with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1393

1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Raises:
            AssertionError: If the engine is not the V1 ``LLMEngine``
                (checked only when Python assertions are enabled).

        Note:
            This method is only available with the V1 LLM engine.
        """
        # Imported lazily; the V1 engine class is only needed for this check.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine), (
            "LLM.get_metrics() is only supported by the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1408
1409
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then submit every prompt.

        Args:
            prompts: A single prompt or a sequence of prompts. A bare ``str``
                or ``dict`` is treated as one prompt.
            params: Either one params object shared by every prompt, or a
                sequence with exactly one entry per prompt. Any
                ``SamplingParams`` given here are mutated in place so that
                only the final output is reported.
            use_tqdm: If truthy, wrap prompt submission in a progress bar. A
                callable is used as the bar factory; otherwise ``tqdm`` is
                used.
            lora_request: Either one LoRA request (or ``None``) shared by
                every prompt, or a sequence with one entry per prompt.
            priority: Optional per-prompt priorities. When ``None`` or empty,
                every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a sequence whose length differs from the
                number of prompts.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Checked up front so that a mismatched list cannot fail partway
        # through the submission loop below, after some requests have already
        # been handed to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        for i, prompt in enumerate(it):
            param = params[i] if isinstance(params, Sequence) else params

            # Per-request tokenizer options, handed to the truncation check
            # and then forwarded to the engine with the request.
            tokenization_kwargs: dict[str, Any] = {}
            _validate_truncation_size(model_config.max_model_len,
                                      param.truncate_prompt_tokens,
                                      tokenization_kwargs)

            self._add_request(
                prompt,
                param,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1461

1462
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The ID is the string form of the next value drawn from
        ``self.request_counter``. All other arguments are forwarded unchanged
        to ``llm_engine.add_request``.
        """
        counter_value = next(self.request_counter)
        self.llm_engine.add_request(
            str(counter_value),
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
1479

1480
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar while stepping. A callable
                is used as the bar factory; otherwise ``tqdm`` is used.

        Returns:
            Every finished output, sorted by integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        num_requests = 0
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # `elapsed` is 0 for a disabled bar and may be 0 right
                        # after the bar is created; skip the speed estimate
                        # instead of dividing by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)

                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the progress bar even if stepping the engine raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))