llm.py 65.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
import vllm.envs as envs
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
16
17
                              BeamSearchSequence,
                              create_sort_beams_key_function)
18
19
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
20
21
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
Joe Runde's avatar
Joe Runde committed
22
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
23
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
24
                                         ChatTemplateContentFormatOption,
25
26
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
27
28
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
29
30
# yapf conflicts with isort for this block
# yapf: disable
31
32
33
34
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
35
                                          compress_token_type_ids,
36
                                          get_score_prompt)
37
# yapf: enable
38
39
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
40
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
41
from vllm.logger import init_logger
42
from vllm.lora.request import LoRARequest
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
45
46
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
47
from vllm.pooling_params import PoolingParams
48
49
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
50
from vllm.tasks import PoolingTask
51
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
52
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
53
from vllm.usage.usage_lib import UsageContext
54
from vllm.utils import Counter, Device, as_iter, is_list_of
55
from vllm.v1.sample.logits_processor import LogitsProcessor
56

57
58
59
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

60
61
logger = init_logger(__name__)

# Result type of callables executed on the workers (used by `collective_rpc`
# and `apply_model`); falls back to `Any` when not otherwise determined.
_R = TypeVar("_R", default=Any)

64
65

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
66
67
68
69
70
71
72
73
74
75
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Note that `best_of` is only supported in V0. When swap space is
            needed, values that are too small may cause out-of-memory (OOM)
            errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        compilation_config: An integer, a dictionary, or a
            `CompilationConfig`. If it is an integer, it is used as the level
            of compilation optimization. If it is a dictionary, it can specify
            the full compilation configuration; keys that are not
            `CompilationConfig` init fields are ignored. A `CompilationConfig`
            instance is used as-is.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """
156
157
158
159

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """Normalize the constructor options and build the engine.

        All options (including anything in ``kwargs``) are forwarded to
        ``EngineArgs``. A few values are first converted into the form the
        engine expects:

        - ``disable_log_stats`` defaults to True unless given explicitly.
        - A ``worker_cls`` passed as a class is serialized with cloudpickle.
        - A ``kv_transfer_config`` passed as a dict is turned into a
          ``KVTransferConfig``; invalid dicts raise ``ValueError``.
        - ``compilation_config`` given as an int, a dict, or None is turned
          into a ``CompilationConfig``.
        """
        # Stats logging stays off for offline use unless explicitly requested.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a type (not a qualified string name) is
        # serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        hf_overrides = {} if hf_overrides is None else hf_overrides

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keys that are not constructor fields of CompilationConfig are
            # dropped rather than passed through.
            compilation_config_instance = CompilationConfig(**{
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            })
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Populated lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        if envs.VLLM_USE_V1:
            tasks = self.llm_engine.get_supported_tasks()  # type: ignore
        else:
            tasks = self.llm_engine.model_config.supported_tasks
        logger.info("Supported_tasks: %s", tasks)
        self.supported_tasks = tasks

287
288
289
290
291
292
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer that the engine's tokenizer group associates
        with ``lora_request`` (``None`` when no LoRA is in use)."""
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
293
294

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that is not already a cached tokenizer is wrapped with
        ``get_cached_tokenizer`` before being stored.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are generated dynamically, so the class-name
        # prefix is the only available check. A user-defined tokenizer whose
        # class name starts with "Cached" is misclassified and stored as-is.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if already_cached else get_cached_tokenizer(tokenizer))
304

305
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling parameters to use when the caller gives none.

        The model config's sampling-parameter overrides are looked up on
        first use and cached on the instance. Without any overrides a plain
        ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = (
                self.llm_engine.model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

313
314
315
316
317
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the defaults from `get_default_sampling_params()` are
                used.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. For
                multimodal models with default modality-specific LoRAs
                configured, a prompt carrying multi-modal data may be assigned
                the matching default LoRA when no request is given for it.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a generative model (its runner
                type is not ``"generate"``).
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
381

382
    def _get_modality_specific_lora_reqs(
            self, prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default modality-specific LoRA requests to prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: ``None``, one LoRA request applied to every prompt,
                or a sequence holding one (optional) request per prompt.

        Returns:
            ``lora_request`` unchanged when there is no LoRA config, no
            ``default_mm_loras``, or the model is not multimodal. Otherwise a
            list with one (optional) LoRA request per prompt.

        Raises:
            ValueError: If ``lora_request`` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # A single prompt must be wrapped. `str` is itself a Sequence, so
        # without the explicit check a text prompt would be iterated one
        # character at a time.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # zip() would silently truncate a mismatched list.
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            optional_loras = list(lora_request)
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

411
    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request for one prompt.

        If the prompt's multi-modal data uses exactly one modality that has a
        default LoRA registered in ``default_mm_loras`` (modality name ->
        LoRA path), a LoRARequest for that default LoRA is built. An
        explicitly provided ``lora_request`` always takes precedence. A
        warning is logged when its ID differs from the default LoRA's ID.

        Args:
            prompt: The prompt to inspect.
            lora_request: The LoRA request already supplied for this prompt,
                if any.
            default_mm_loras: Mapping of modality name to LoRA path, or None.

        Returns:
            The LoRA request to use for this prompt, or None.
        """
        if (not default_mm_loras or not isinstance(prompt, dict)
                or "multi_modal_data" not in prompt):
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)
        mm_data = prompt["multi_modal_data"]
        # An explicit `None` (or empty) multi_modal_data means "no multimodal
        # input"; avoid calling `.keys()` on None.
        if not mm_data:
            return lora_request

        intersection = set(mm_data.keys()).intersection(
            default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; only warn when its ID
        # collides with (differs from) the default modality LoRA's ID.
        if lora_request is not None:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

457
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke ``method`` on every worker and collect each worker's result.

        Args:
            method: Either the name of a method on the worker, or a callable
                that is serialized and shipped to every worker. A callable
                receives the worker object as its first (``self``) argument,
                followed by ``args`` and ``kwargs``.
            timeout: Upper bound, in seconds, on how long to wait; exceeding
                it raises [`TimeoutError`][]. ``None`` waits without limit.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker.

        Note:
            Prefer this API for control messages; set up a separate
            data-plane channel for moving bulk data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
486
487

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call ``func`` on the model held by each worker.

        Delegates to the engine's model executor and returns its list of
        per-worker results.
        """
        return self.llm_engine.model_executor.apply_model(func)
494

495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: ``None``, a single LoRA request shared by all
                prompts, or a sequence with one (optional) request per prompt.
            prompts: The beam-search prompts.

        Returns:
            A list holding exactly one (optional) LoRA request per prompt.

        Raises:
            ValueError: If ``lora_request`` is a sequence whose length differs
                from ``len(prompts)``.
            TypeError: If ``lora_request`` is none of the accepted types.
        """
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        # A correctly sized per-prompt list is valid input; previously it
        # fell through to the TypeError below. Strings are excluded since
        # they are Sequences but never valid LoRA requests.
        if isinstance(lora_request, Sequence) and not isinstance(
                lora_request, str):
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

511
512
    def beam_search(
        self,
513
        prompts: list[Union[TokensPrompt, TextPrompt]],
514
        params: BeamSearchParams,
515
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
516
        use_tqdm: bool = False,
517
        concurrency_limit: Optional[int] = None,
518
    ) -> list[BeamSearchOutput]:
519
520
521
522
523
524
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
525
            params: The beam search parameters.
526
            lora_request: LoRA request to use for generation, if any.
527
            use_tqdm: Whether to use tqdm to display the progress bar.
528
529
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
530
        """
531
532
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
533
534
535
536
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
537
538
        length_penalty = params.length_penalty

539
540
541
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

542
543
544
545
546
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
547

548
549
550
551
552
553
554
555
556
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

557
558
559
560
561
562
563
564
565
566
567
568
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
569

570
571
572
573
574
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
575
                                            temperature=temperature)
576
        instances: list[BeamSearchInstance] = []
577

578
        for lora_req, prompt in zip(lora_requests, prompts):
579
580
581
582
583
584
585
586
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

587
588
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
589
590
591
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
592

593
            instances.append(
594
595
596
597
598
599
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
600

601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
674
675
676
677
678

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
679
                                      key=sort_beams_key,
680
681
682
683
684
685
686
687
688
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered with the chat template, tokenized,
        and the resulting token prompts are passed to the
        [generate][vllm.LLM.generate] method to produce the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation or a list of conversations.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. It is
                also used to select the tokenizer.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, the chat template is asked to add
                its generation prompt after the conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are used when resolving the
                content format and are also passed to the chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Keys given here override the values derived from
                `chat_template`, `add_generation_prompt`,
                `continue_final_message` and `tools`.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Normalize to a list of conversations (list[list[message]]).
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(list[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Values derived from the explicit arguments come first so that the
        # caller's `chat_template_kwargs` can override them.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        template_kwargs.update(chat_template_kwargs or {})

        use_mistral = isinstance(tokenizer, MistralTokenizer)
        engine_prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if use_mistral:
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            engine_prompts.append(engine_prompt)

        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: Optional[PoolingTask] = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If not None, this value is assigned to
                `truncate_prompt_tokens` on every pooling parameter object
                (kept for backwards compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. Defaults to `"encode"`.
                If `None` is passed explicitly, `"embed"` is chosen when the
                model supports it (otherwise `"encode"`), and a one-time
                warning pointing to the task-specific methods is logged.
            tokenization_kwargs: Accepted for API compatibility. It is not
                currently used by this method.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, or if the
                selected `pooling_task` is not in `self.supported_tasks`.
        """
        if pooling_task is None:
            # The caller opted into automatic task selection.
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
            else:
                pooling_task = "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                "  - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        # NOTE(review): `tokenization_kwargs` is accepted but never forwarded
        # anywhere; confirm whether `_validate_and_add_requests` should
        # receive it.
        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Forwarded to `encode`, which assigns it
                to every pooling parameter object when not None.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"embed"` task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Forwarded to `encode`, which assigns it
                to every pooling parameter object when not None. Defaults to
                None (no change), matching `embed` and `reward`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"classify"` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This delegates to `encode` with `pooling_task="encode"`.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Forwarded to `encode`, which assigns it
                to every pooling parameter object when not None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        reward_outputs = self.encode(
            prompts,
            pooling_task="encode",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return reward_outputs

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        `text_1 + text_2` is embedded in a single `encode` call with the
        `"embed"` task. The outputs are then split back into the two sides.
        When the first side has exactly one embedding, it is repeated to
        match the length of the second side (the `1 -> N` case).
        """
        n_first = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

        first_side = embeddings[:n_first]
        second_side = embeddings[n_first:]
        if len(first_side) == 1:
            # Broadcast the single query embedding against every document.
            first_side = first_side * len(second_side)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=first_side,
                                          embed_2=second_side)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair with a cross-encoder.

        A single-element `data_1` is repeated to match `len(data_2)`. Each
        pair is turned into one engine prompt via `get_score_prompt`. When
        that prompt carries `token_type_ids`, they are removed from the
        prompt, compressed, and attached to a cloned `PoolingParams`.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        # Filled in by `_validate_truncation_size` from
        # `truncate_prompt_tokens`.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts = list[PromptType]()
        pooling_params_list = list[PoolingParams]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                # Per-request params are needed to carry this pair's
                # token type ids, so clone rather than mutate the shared one.
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        For cross-encoder models the input pairs are used to build a list of
        prompts; for other models both sides are embedded and compared. The
        prompts are batched automatically, considering the memory constraint.
        For the best performance, put all of your inputs into a single list
        and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Forwarded unchanged to the cross-encoder
                or embedding scoring path.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                `"embed"` nor `"classify"`, is a cross-encoder with
                `num_labels != 1`, or (text-only models) if an input is
                multi-modal or cannot be converted to a string.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`: asserts are stripped
                # under `python -O` and would let non-strings through.
                if type(prompt) is not str:
                    raise ValueError(
                        "Scoring prompts must be strings, TextPrompt or "
                        f"TokensPrompt, got {type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
1306

1307
1308
1309
1310
1311
1312
    def start_profile(self) -> None:
        """Begin profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """End profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1313
1314
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Ask the engine to reset its prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to the
                engine's ``reset_prefix_cache``.

        Returns:
            The boolean reported by the engine's ``reset_prefix_cache``.
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(device)
        return result
1315

1316
1317
1318
1319
1320
1321
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep so that it processes no requests.

        The caller must make sure no requests are in flight for the whole
        sleep period, i.e. until `wake_up` is called.

        Args:
            level: The sleep level.
                Level 1 offloads the model weights and discards the kv cache,
                whose content is forgotten. The weights are backed up in CPU
                memory, so enough CPU memory must be available to hold them.
                Use it when the engine will wake up to run the same model
                again.
                Level 2 discards both the model weights and the kv cache;
                the content of both is forgotten. Use it when the engine will
                wake up to run a different or updated model and the previous
                weights are not needed. It reduces CPU memory pressure.
        """
        # Reset the prefix cache first, then hand off to the engine.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1338
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Resume the engine after a call to [sleep][vllm.LLM.sleep].

        Args:
            tags: Optional list of tags selecting which engine memory
                allocations to reallocate. Accepted values are
                `("weights", "kv_cache")`; `None` reallocates everything.
                Before the engine is used again, wake_up must have been
                called with every tag (or with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1351

1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus, as produced by the V1
            engine's ``get_metrics``.

        Raises:
            RuntimeError: If the underlying engine is not the V1 LLM engine.

        Note:
            This method is only available with the V1 LLM engine.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        # An explicit check (rather than `assert`) keeps this guard active
        # under `python -O` and gives the caller a descriptive error.
        if not isinstance(self.llm_engine, V1LLMEngine):
            raise RuntimeError(
                "get_metrics() is only available with the V1 LLM engine, "
                f"but the engine is {type(self.llm_engine).__name__}.")
        return self.llm_engine.get_metrics()

1366
1367
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then submit every prompt.

        All length checks run before any request is handed to the engine,
        so a mismatch never leaves a partially-submitted batch behind.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. ``SamplingParams`` objects
                are mutated in place to use ``RequestOutputKind.FINAL_ONLY``.
            use_tqdm: If truthy, wrap submission in a progress bar; a
                callable is used as the progress-bar factory instead of
                ``tqdm``.
            lora_request: One LoRA request shared by every prompt, a sequence
                with one entry per prompt, or ``None``.
            priority: Optional per-prompt priorities. When non-empty it must
                have one entry per prompt; when ``None`` or empty every
                request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                sequence whose length differs from the number of prompts.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Previously a short list raised IndexError only after the earlier
        # requests had already been queued in the engine. An empty list keeps
        # its old meaning of "default priority for every request".
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        for i, prompt in enumerate(it):
            param = params[i] if isinstance(params, Sequence) else params

            # _validate_truncation_size checks the param's
            # truncate_prompt_tokens against max_model_len and presumably
            # fills in truncation settings here -- see its definition.
            tokenization_kwargs: dict[str, Any] = {}
            _validate_truncation_size(model_config.max_model_len,
                                      param.truncate_prompt_tokens,
                                      tokenization_kwargs)

            self._add_request(
                prompt,
                param,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1419

1420
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request ID.

        The ID is the next value drawn from ``self.request_counter``,
        converted to ``str``. All other arguments are forwarded unchanged
        to ``self.llm_engine.add_request``.
        """
        new_id = str(next(self.request_counter))
        engine = self.llm_engine
        engine.add_request(
            new_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
1437

1438
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every queued request has finished.

        Args:
            use_tqdm: If truthy, show a progress bar with estimated input and
                output token throughput; a callable is used as the
                progress-bar factory instead of ``tqdm``.

        Returns:
            The finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # On coarse clocks the elapsed time can still be 0.0;
                        # skip the speed estimate rather than divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if stepping the engine raised.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were added before them.
        return sorted(outputs, key=lambda x: int(x.request_id))