llm.py 63.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.entrypoints.utils import _validate_truncation_size
29
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
30
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
31
from vllm.logger import init_logger
32
from vllm.lora.request import LoRARequest
33
34
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
35
36
37
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
38
from vllm.pooling_params import PoolingParams
39
from vllm.prompt_adapter.request import PromptAdapterRequest
40
41
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
42
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
43
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
44
from vllm.usage.usage_lib import UsageContext
45
46
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
47

48
49
logger = init_logger(__name__)

# Generic return type of callables executed on workers (see
# `LLM.collective_rpc` and `LLM.apply_model`); falls back to `Any` when it
# cannot be inferred.
_R = TypeVar("_R", default=Any)

52
53

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
54
55
56
57
58
59
60
61
62
63
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
64
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
65
66
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
67
68
69
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
70
71
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
72
73
74
75
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
76
77
78
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
79
80
81
82
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
83
        quantization: The method used to quantize the model weights. Currently,
84
            we support "awq", "gptq", and "fp8" (experimental).
85
86
87
88
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
89
90
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
91
92
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
93
94
95
96
97
98
99
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
100
101
102
103
104
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
105
106
107
108
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
109
110
111
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
112
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
113
            When a sequence has context length larger than this, we fall back
114
115
116
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
117
118
119
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
120
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
123
124
125
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
126
127
128
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the level of compilation optimization. If it is a dictionary, it
            can specify the full compilation configuration. A
            `CompilationConfig` instance is used as-is.
129
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
130
            :ref:`engine-args`)
nunjunj's avatar
nunjunj committed
131

132
133
134
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
135
    """
136

137
    # Both flags are read through the `lambda`s passed to the deprecation
    # decorators (`deprecate_kwargs` on `generate`, `deprecate_args` on
    # `__init__`), so they can be toggled on the class at runtime.
    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

146
147
148
149
150
151
152
153
154
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

155
156
157
158
159
160
161
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        See the class docstring for the meaning of each argument. Any extra
        ``kwargs`` are forwarded to :class:`~vllm.EngineArgs`.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        """
        # Stats logging is off unless the caller explicitly passes
        # `disable_log_stats`.
        kwargs.setdefault("disable_log_stats", True)

        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            # A class object (as opposed to a qualified-name string) is
            # serialized with cloudpickle to avoid pickling issues.
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            # Integers and dicts go through the CLI parser via their str()
            # form.
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            # Anything else (e.g. a ready-made CompilationConfig) is passed
            # through unchanged.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None
254

255
256
257
258
259
260
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
            lora_request)
261
262

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
263
        tokenizer_group = self.llm_engine.get_tokenizer_group()
264

265
266
267
268
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
269
            tokenizer_group.tokenizer = tokenizer
270
        else:
271
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
272

273
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        The model config's diff sampling params are fetched on first use and
        cached on ``self.default_sampling_params``. A non-empty mapping is
        turned into ``SamplingParams`` via ``from_optional``; otherwise the
        plain ``SamplingParams()`` defaults are returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

281
282
283
284
285
286
287
    # NOTE: every overload declares `priority` so that type-checked callers
    # can pass it; the implementation below accepts it on all paths.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
378
379
380
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy. Token IDs of the prompt(s). When given,
                they are combined with ``prompts`` by the legacy
                ``_convert_v1_inputs`` path. Prefer passing
                :class:`~vllm.inputs.TokensPrompt` objects via ``prompts``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. If given
                as a dict, it may specify at most one guided decoding mode and
                is converted to a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate" or
                "transcription" runner, or if ``guided_options_request`` is a
                dict with more than one entry.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            removed in the future. You should instead pass token IDs via the
            ``prompts`` parameter (e.g. as :class:`~vllm.inputs.TokensPrompt`).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = (
                self.llm_engine.model_config.supported_runner_types)
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge `prompts` with separately-passed token IDs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
476

477
    def collective_rpc(self,
478
                       method: Union[str, Callable[..., _R]],
479
                       timeout: Optional[float] = None,
480
481
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
        
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
504
505

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
506
507

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
508
        """
509
510
        Run a function directly on the model inside each worker,
        returning the result for each of them.
511
        """
512
513
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
514

515
516
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a
                :class:`~vllm.inputs.TokensPrompt`, whose
                ``"prompt_token_ids"`` are used directly, or a
                :class:`~vllm.inputs.TextPrompt`, whose ``"prompt"`` string
                is encoded with the tokenizer. Optional
                ``"multi_modal_data"`` and ``"mm_processor_kwargs"`` entries
                are attached to the search instance and re-sent with each
                beam's next-step prompt.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per input prompt, in input order.
            Each holds at most ``params.beam_width`` sequences sorted by beam
            search score (best first), each with ``text`` decoded by the
            tokenizer.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances into one batch.
            # `chain.from_iterable` is linear, unlike `sum(lists, [])`, which
            # re-copies the accumulated list on every addition.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # Instance k owns the slice all_beams[pos[k]:pos[k + 1]].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            # An EOS token finishes the beam unless EOS is
                            # explicitly ignored.
                            if (token_id == tokenizer.eos_token_id
                                    and not ignore_eos):
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after the loop compete with completed ones
            # for the final `beam_width` slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
647
648
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered through the chat template into prompt
        token IDs (``apply_mistral_chat_template`` for Mistral tokenizers,
        otherwise ``apply_hf_chat_template`` followed by ``tokenizer.encode``
        without extra special tokens). The resulting prompts are handed to
        :meth:`generate` in a single call.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation, or a list of conversations.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. It is
                also passed to :meth:`get_tokenizer` when choosing the
                tokenizer that renders the conversations.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, the chat template is asked to
                append a generation prompt after the conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions. They are forwarded to the chat
                template and also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. On a key collision they take precedence over the
                template options above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        # A single conversation is a flat list of messages; a batch is a list
        # of such lists. Normalize to the batched shape.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(list[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        trust_remote_code = model_config.trust_remote_code

        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=trust_remote_code,
        )

        # Named options first, so user-provided chat_template_kwargs can
        # override them.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        if chat_template_kwargs:
            template_kwargs.update(chat_template_kwargs)

        uses_mistral = isinstance(tokenizer, MistralTokenizer)

        def render(conv_messages: list[ChatCompletionMessageParam]):
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conv_messages,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if uses_mistral:
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conv_messages,
                    **template_kwargs,
                )
            else:
                rendered_text = apply_hf_chat_template(
                    tokenizer,
                    trust_remote_code=trust_remote_code,
                    conversation=conversation,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered_text,
                                             add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            return engine_prompt

        prompts: list[Union[TokensPrompt, TextPrompt]] = [
            render(conv) for conv in conversations
        ]

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

789
790
791
792
793
794
795
    # Preferred form: ``prompts`` (positional-only) carries text, token IDs
    # and any multi-modal data; everything after pooling_params is keyword.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]: ...

    # LEGACY: one text prompt, optionally with its token IDs.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]: ...

    # LEGACY: several text prompts, optionally with their token IDs.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]: ...

    # LEGACY: one sequence of token IDs (keyword), text prompt optional.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]: ...

    # LEGACY: several sequences of token IDs (keyword), text prompts optional.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]: ...

    # LEGACY: token IDs only, passed positionally after two explicit Nones.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]: ...

nunjunj's avatar
nunjunj committed
880
881
882
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Either a single
                ``PoolingParams`` or a sequence of them; every supplied value
                is verified against the model config.
            prompt_token_ids: Deprecated. Token IDs for the prompts, converted
                together with ``prompts`` through the legacy input path.
            truncate_prompt_tokens: If not None, prompts are truncated during
                tokenization; the value is validated against the model's
                ``max_model_len``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the pooling runner.

        Note:
            Passing ``prompt_token_ids`` separately is a legacy calling
            convention and is deprecated. Pass token IDs through the
            ``prompts`` parameter instead (for example as a ``TokensPrompt``).
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            # Give a more actionable hint when the model *could* pool but
            # was started with a generation task.
            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge separate text/token-id arguments into
            # prompt objects.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        # Populated with truncation settings when truncate_prompt_tokens is
        # given; forwarded to request construction below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
972

973
974
975
976
977
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Produce one embedding vector per input prompt.

        Thin wrapper over :meth:`encode` that is only available when the
        model was started with ``--task embed``. Prompts are batched
        automatically with the memory constraint in mind, so for the best
        performance pass all prompts in a single list.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            truncate_prompt_tokens: Forwarded to :meth:`encode`; if not None,
                prompts are truncated during tokenization.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model task is not ``embed``.
        """
        current_task = self.llm_engine.model_config.task
        if current_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled_outputs = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification results (class logits) in the same order as the
            input prompts.

        Raises:
            ValueError: If the model task is not ``classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1060
1061
1062
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by the cosine similarity of their embeddings.

        Both sides are embedded in a single :meth:`encode` call, ``text_1``
        first and then ``text_2``, and the pooled outputs are split back
        apart by position. When ``text_1`` yields exactly one output, that
        output is reused for every ``text_2`` entry.

        Returns:
            The results of ``_cosine_similarity`` after validation, wrapped
            as ``ScoringRequestOutput`` objects.
        """
        split_at = len(text_1)
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        left: list[PoolingRequestOutput] = pooled[:split_at]
        right: list[PoolingRequestOutput] = pooled[split_at:]

        # 1 -> N scoring: broadcast the single query embedding.
        if len(left) == 1:
            left = left * len(right)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=left,
                                          embed_2=right)

        checked = self.engine_class.validate_outputs(similarities,
                                                     PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in checked]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs with a cross-encoder model.

        Each ``(text_1[i], text_2[i])`` pair is tokenized jointly, passed as
        ``text``/``text_pair``. The pairs are then submitted as pooling
        requests with default ``PoolingParams``. A single ``text_1`` entry is
        broadcast against every ``text_2`` entry.

        Args:
            tokenizer: Tokenizer used to encode each text pair.
            text_1: Query texts (length 1 or ``len(text_2)``).
            text_2: Texts paired with the queries.
            truncate_prompt_tokens: If not None, pair tokenization is
                truncated; validated against the model's ``max_model_len``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects, in pair order.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The caller has already validated the task, so the old
            # "--task embed or score" message was misleading here.
            # NOTE(review): presumably rejected because the pair tokenization
            # below relies on the HF ``text``/``text_pair`` call signature.
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        # Populated with truncation settings (if requested) and forwarded to
        # the tokenizer call for every pair.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1144
1145
1146
1147
1148
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        For cross-encoder models each pair is tokenized together and scored
        by the model. Otherwise both sides are embedded and compared with
        cosine similarity. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Prompts given as token IDs are decoded back to text first, since some
        cross-encoder tokenizers cannot take token lists as ``text`` /
        ``text_pair``.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: If not None, the tokenized inputs are
                truncated; the value is validated against the model's
                ``max_model_len``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, the task is
                neither ``embed`` nor ``score``, a prompt carries multi-modal
                data, or a prompt cannot be converted to a string.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`: this validates user input
            # and must not disappear under `python -O`.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Each scoring prompt must be a string, a TextPrompt or a "
                    f"TokensPrompt, but got {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1244

1245
1246
1247
1248
1249
1250
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1251
1252
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Forwarded unchanged to the engine's
                ``reset_prefix_cache``.

        Returns:
            The boolean result reported by the engine.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1253

1254
1255
1256
1257
1258
1259
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep so that it frees resources.

        No requests may be processed while the engine sleeps. The caller is
        responsible for ensuring nothing is in flight until :meth:`wake_up`
        has been called. The prefix cache is cleared before the engine is put
        to sleep.

        Args:
            level: The sleep level.

              - Level 1 offloads the model weights to CPU memory and discards
                the KV cache, whose contents are lost. Use it to pause and
                later resume the same model; make sure there is enough CPU
                memory to hold the weights.
              - Level 2 discards both the model weights and the KV cache, so
                neither is preserved. Use it when the next wake-up will run a
                different or updated model and the previous weights are no
                longer needed; it reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1276
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Bring the engine back from sleep mode (see :meth:`sleep`).

        Args:
            tags: Optional list of tags that restricts which engine memory
                gets reallocated. Allowed values are ``"weights"`` and
                ``"kv_cache"``. With None, all memory is reallocated. Before
                the engine is used again, wake_up must have been called with
                every tag (or with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1289

1290
1291
    # LEGACY
    def _convert_v1_inputs(
1292
        self,
1293
1294
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
1295
1296
    ):
        # skip_tokenizer_init is now checked in engine
1297

1298
1299
1300
1301
1302
1303
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
1304

1305
        num_requests = None
1306
1307
        if prompts is not None:
            num_requests = len(prompts)
1308
1309
1310
1311
1312
1313
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

1314
            num_requests = len(prompt_token_ids)
1315
1316
1317
1318
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

1319
        parsed_prompts: list[PromptType] = []
1320
        for i in range(num_requests):
1321
            item: PromptType
1322

1323
            if prompts is not None:
1324
1325
1326
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
1327
            else:
1328
                raise AssertionError
1329

1330
            parsed_prompts.append(item)
1331

1332
        return parsed_prompts
1333
1334
1335

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate a batch of prompts and submit each one to the engine.

        Args:
            prompts: A single prompt (``str`` or dict-style prompt) or a
                sequence of prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. Every ``SamplingParams``
                among them is switched to ``RequestOutputKind.FINAL_ONLY``.
            lora_request: One LoRA request shared by every prompt, a sequence
                with exactly one entry per prompt, or ``None``.
            prompt_adapter_request: Prompt adapter request used for every
                prompt.
            tokenization_kwargs: Extra tokenization arguments forwarded with
                every request.
            guided_options: Deprecated. Merged into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities. When ``None`` or empty,
                every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a sequence whose length differs from the
                number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Any Sequence (list, tuple, ...) means "one entry per prompt". The
        # indexing in the submit loop below already treats every Sequence
        # that way, so validation and per-params setup must agree with it.
        # Checking only for `list` let a tuple skip both steps.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if (isinstance(lora_request, Sequence)
                and len(lora_request) != num_requests):
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Validate before enqueueing anything, so a short list cannot fail
        # with IndexError after some requests were already added.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1384

1385
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one prompt to the engine under a freshly minted request ID.

        The request ID is the string form of the next value drawn from
        ``self.request_counter``.

        Args:
            prompt: The prompt to enqueue.
            params: Sampling or pooling parameters for this request.
            tokenization_kwargs: Extra tokenization arguments, forwarded.
            lora_request: Optional LoRA request, forwarded.
            prompt_adapter_request: Optional prompt adapter request,
                forwarded.
            priority: Scheduling priority, forwarded; defaults to 0.
        """
        forwarded = dict(
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, params,
                                    **forwarded)
1404

1405
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Fold deprecated guided-decoding options into ``params``.

        ``params`` is modified in place and also returned.

        Args:
            params: Sampling parameters to update.
            guided_options: Legacy guided-decoding request. When ``None``,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            # Refuse to silently override guided decoding set on the params.
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1428
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every queued request has finished.

        Args:
            use_tqdm: If True, show a progress bar with running estimates of
                input and output token throughput.

        Returns:
            The finished outputs, sorted by ``int(request_id)``.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # tqdm may report an elapsed time of 0 (before its
                        # start time is set, or on coarse clocks); avoid a
                        # ZeroDivisionError in that case.
                        elapsed = pbar.format_dict["elapsed"]
                        in_spd = total_in_toks / elapsed if elapsed else 0.0
                        out_spd = (total_out_toks /
                                   elapsed if elapsed else 0.0)
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if a step raises, so the terminal is not
            # left with a dangling progress line.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))