llm.py 64 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.entrypoints.utils import _validate_truncation_size
29
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
30
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
31
from vllm.logger import init_logger
32
from vllm.lora.request import LoRARequest
33
34
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
35
from vllm.model_executor.layers.quantization import QuantizationMethods
36
37
38
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
39
from vllm.pooling_params import PoolingParams
40
from vllm.prompt_adapter.request import PromptAdapterRequest
41
42
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
43
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
44
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
47
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
48

49
50
# Module-level logger for this file.
logger = init_logger(__name__)

51
52
# Result type of the callables accepted by `LLM.collective_rpc` and
# `LLM.apply_model` (both return `list[_R]`); `default=Any` makes it fall
# back to `Any` when it is not otherwise determined.
_R = TypeVar("_R", default=Any)

53
54

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
55
56
57
58
59
60
61
62
63
64
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
65
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
66
67
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
68
69
70
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
71
72
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
73
74
75
76
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
77
78
79
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
80
81
82
83
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
84
        quantization: The method used to quantize the model weights. Currently,
85
            we support "awq", "gptq", and "fp8" (experimental).
86
87
88
89
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
90
91
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
92
93
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
94
95
96
97
98
99
100
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
101
102
103
104
105
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
106
107
108
109
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
110
111
112
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
113
        max_seq_len_to_capture: Maximum sequence length covered by CUDA graphs.
114
            When a sequence has context length larger than this, we fall back
115
116
117
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
118
119
120
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
121
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
124
125
126
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
127
128
129
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
130
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
131
            :ref:`engine-args`)
nunjunj's avatar
nunjunj committed
132

133
134
135
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
136
    """
137

138
    DEPRECATE_LEGACY: ClassVar[bool] = True
139
140
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

141
142
143
144
145
146
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

147
148
149
150
151
152
153
154
155
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

156
157
158
159
160
161
162
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Normalizes a few engine keyword arguments, bundles every setting
        into an :class:`~vllm.EngineArgs`, and creates the engine through
        ``LLMEngine.from_engine_args``.
        """
        # Stats logging is off unless the caller set it explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A class object (as opposed to a qualified-name string) is
        # serialized with cloudpickle up front to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # int/dict values are parsed through CompilationConfig.from_cli;
        # None and any other value are forwarded unchanged.
        compilation_config_instance = compilation_config
        if isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily filled by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
250

251
252
253
254
255
256
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
            lora_request)
257
258

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
259
        tokenizer_group = self.llm_engine.get_tokenizer_group()
260

261
262
263
264
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
265
            tokenizer_group.tokenizer = tokenizer
266
        else:
267
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
268

269
    def get_default_sampling_params(self) -> SamplingParams:
        """Return sampling params built from the model config's overrides.

        The mapping returned by ``model_config.get_diff_sampling_param()``
        is fetched lazily and stored on ``self.default_sampling_params``.
        An empty mapping yields a plain ``SamplingParams()``.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

277
278
279
280
281
282
283
    # Preferred signature: ``prompts`` is positional-only and every option
    # after ``sampling_params`` is keyword-only.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    # LEGACY: a single text prompt plus optional token IDs.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    # LEGACY: several text prompts plus optional token IDs for each.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    # LEGACY: a single token-ID prompt (keyword) plus optional text prompt.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    # LEGACY: several token-ID prompts (keyword) plus optional text prompts.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    # LEGACY: single or multiple token-ID prompts, passed positionally.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

nunjunj's avatar
nunjunj committed
374
375
376
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy. Token IDs of the prompt(s). Prefer
                passing :class:`~vllm.inputs.TokensPrompt` entries via
                ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A plain
                dict is converted into a ``GuidedDecodingRequest`` and must
                not contain more than one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is neither ``"generate"``
                nor ``"transcription"``, or if ``guided_options_request`` is
                a dict with more than one entry.

        Note:
            Passing ``prompts`` as a keyword argument, or passing token IDs
            via ``prompt_token_ids``, is considered legacy and may be
            deprecated in the future. Pass prompts positionally instead;
            token IDs can be given as :class:`~vllm.inputs.TokensPrompt`
            entries of ``prompts``.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ("generate", "transcription"):
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                # Point the user at the fix when the model could generate.
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: combine `prompts` and `prompt_token_ids` into the
            # prompt format consumed below.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
474

475
    def collective_rpc(self,
476
                       method: Union[str, Callable[..., _R]],
477
                       timeout: Optional[float] = None,
478
479
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
        
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
502
503

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
504
505

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
506
        """
507
508
        Run a function directly on the model inside each worker,
        returning the result for each of them.
509
        """
510
511
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
512

513
514
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a
                :class:`~vllm.inputs.TextPrompt` (its ``"prompt"`` text is
                encoded with this LLM's tokenizer) or a
                :class:`~vllm.inputs.TokensPrompt` (its
                ``"prompt_token_ids"`` are used as-is). Optional
                ``"multi_modal_data"`` and ``"mm_processor_kwargs"`` entries
                are carried along with every beam.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            holds up to ``params.beam_width`` sequences sorted by beam search
            score (best first), with ``text`` set by decoding the sequence's
            full ``tokens`` list.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()
        # Hoisted: read once instead of for every candidate and sort key.
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs: dict[str, Any] = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten all live beams into one batch. chain.from_iterable is
            # linear, unlike `sum(lists, [])`, which re-copies the growing
            # list once per instance.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            # [start, end) slice of `all_beams` owned by each instance.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams: list[BeamSearchSequence] = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
645
646
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the chat template and tokenized
        into a :class:`~vllm.inputs.TokensPrompt`; the resulting prompts are
        passed to the :meth:`generate` method to produce the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. It also
                selects the tokenizer used to render the conversations.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

              The default ``"auto"`` is resolved from the chat template and
              ``tools`` before the messages are parsed.
            add_generation_prompt: If True, ask the chat template to append a
                generation prompt to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions. They are forwarded to the chat
                template and used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Keys given here override the values derived from
                ``chat_template``, ``add_generation_prompt``,
                ``continue_final_message`` and ``tools``.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests. They are
                attached to every prompt built from ``messages``.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        # Caller-supplied template kwargs take precedence over the ones
        # derived from the explicit arguments above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path consumes the raw messages and yields the
                # prompt token ids directly.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer,
                    trust_remote_code=model_config.trust_remote_code,
                    conversation=conversation,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    # Preferred signature: ``prompts`` is positional-only, ``pooling_params``
    # may be positional or keyword, and every other option is keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A single value or a
                sequence may be given; every value provided is verified
                against the model config.
            prompt_token_ids: (Legacy) token IDs for the prompts. Combined
                with ``prompts`` into the unified prompt format; prefer
                passing token IDs through ``prompts`` instead.
            truncate_prompt_tokens: Optional truncation length for the
                prompt tokens. It is validated together with the model's
                ``max_model_len`` and turned into tokenization kwargs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``pooling``
                runner.

        Note:
            Passing ``prompt_token_ids`` as a separate parameter is
            considered legacy and may be deprecated in the future. You should
            instead pass token IDs via the ``prompts`` parameter (e.g. as
            :class:`~vllm.inputs.TokensPrompt`).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge `prompts` and `prompt_token_ids` into the
            # unified prompt representation.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(self.llm_engine.model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(self.llm_engine.model_config)

        # Handed to _validate_truncation_size together with max_model_len;
        # presumably it records the truncation setting here so it reaches
        # the tokenizer via _validate_and_add_requests.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation length for the prompt
                tokens; forwarded unchanged to :meth:`encode`.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task embed``.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Optional truncation length for the prompt
                tokens; forwarded unchanged to :meth:`encode`. Defaults to
                None (no truncation requested), matching previous behavior.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by the cosine similarity of their embeddings.

        Both sides are embedded in a single :meth:`encode` call. When
        ``text_1`` has exactly one entry, its embedding is reused for every
        entry of ``text_2`` (the ``1 -> N`` case).

        Args:
            tokenizer: Tokenizer handed to the cosine-similarity helper.
            text_1: First side of each pair.
            text_2: Second side of each pair.
            truncate_prompt_tokens: Optional truncation length forwarded to
                :meth:`encode`.
            use_tqdm: Whether to display a progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair, in input order.
        """
        num_text_1 = len(text_1)
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        # encode() returns outputs in input order, so split the combined
        # batch back into its two sides.
        encoded_output_1: list[PoolingRequestOutput] = \
            encoded_output[:num_text_1]
        encoded_output_2: list[PoolingRequestOutput] = \
            encoded_output[num_text_1:]

        if len(encoded_output_1) == 1:
            # 1 -> N: pair the single query embedding with every text_2.
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score ``(text_1[i], text_2[i])`` pairs with a cross-encoder.

        Each pair is tokenized jointly (as ``text`` / ``text_pair``) into a
        single :class:`TokensPrompt` and run through the pooling engine.

        Args:
            tokenizer: Tokenizer used to encode the pairs. Mistral tokenizers
                are rejected.
            text_1: Queries. A single entry is broadcast to every ``text_2``.
            text_2: Texts to pair with the queries.
            truncate_prompt_tokens: Optional truncation length. It is
                validated with the model's ``max_model_len`` and the
                resulting kwargs are passed to the tokenizer.
            use_tqdm: Whether to display a progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair, in input order.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
        """
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(text_1) == 1:
            # 1 -> N: replicate the single query for every text_2 entry.
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        For cross-encoder models each pair is tokenized jointly and scored
        by the model; otherwise both sides are embedded and each pair is
        scored by the cosine similarity of its embeddings. This method
        automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your texts into a
        single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation length for the
                tokenized inputs, validated with the model's
                ``max_model_len``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``pooling``
                runner, its task is neither ``embed`` nor ``score``, or a
                prompt carries multi-modal data.
            TypeError: If a prompt cannot be converted to a string.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs, so every
        # prompt is normalized to a plain string below.
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Return the text form of a singleton prompt."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    # Decode token ids back to text (see note above).
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check rather than ``assert`` so the validation is not
            # stripped when Python runs with ``-O``.
            if not isinstance(prompt, str):
                raise TypeError("Unsupported prompt type for scoring: "
                                f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)

    def start_profile(self) -> None:
        """Begin profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine.

        Counterpart of :meth:`start_profile`.
        """
        self.llm_engine.stop_profile()

    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to
                ``llm_engine.reset_prefix_cache``.

        Returns:
            The boolean returned by the engine (presumably whether the reset
            succeeded; confirm against the engine implementation).
        """
        return self.llm_engine.reset_prefix_cache(device)

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset before the engine goes to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Both sleep levels discard the KV cache, so clear the prefix cache
        # first (presumably so no cached prefix refers to discarded blocks).
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                ("weights", "kv_cache",). If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Turn the legacy ``prompts`` / ``prompt_token_ids`` arguments into
        a list of prompt objects.

        Each argument may be a single item or a batch; both are normalized
        via ``parse_and_batch_prompt``. When both are supplied they must have
        the same length, and the text prompts are used (the token IDs are
        then ignored).

        Raises:
            ValueError: If neither argument is given, or if both are given
                with different lengths.
        """
        # skip_tokenizer_init is now checked in engine
        texts = (None if prompts is None else
                 [p["content"] for p in parse_and_batch_prompt(prompts)])
        token_lists = (None if prompt_token_ids is None else [
            p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
        ])

        if texts is None and token_lists is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if (texts is not None and token_lists is not None
                and len(texts) != len(token_lists)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        parsed: list[PromptType]
        if texts is not None:
            parsed = [TextPrompt(prompt=text) for text in texts]
        else:
            assert token_lists is not None
            parsed = [
                TokensPrompt(prompt_token_ids=ids) for ids in token_lists
            ]
        return parsed
1333
1334
1335

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and add every prompt to the engine.

        Args:
            prompts: A single prompt (``str`` or prompt ``dict``) or a
                sequence of prompts.
            params: One params object shared by all prompts, or a sequence
                (list, tuple, ...) with one entry per prompt.
            use_tqdm: Show a progress bar while adding requests.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or ``None``.
            prompt_adapter_request: Passed unchanged to every request.
            tokenization_kwargs: Passed unchanged to every request.
            guided_options: Deprecated. Merged into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities. When omitted (or
                empty), every request gets priority ``0``.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a per-prompt sequence whose length differs
                from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test that is used below to index into
        # these arguments. Previously only ``list`` was validated here, so a
        # tuple skipped the length check and the per-params loop.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Validate up front so a short list cannot fail halfway through,
        # after some requests have already been added to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1390

1391
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one prompt to the engine under a freshly drawn request ID.

        The ID is the next value of ``self.request_counter`` converted to
        ``str``. ``_run_engine`` sorts finished outputs by ``int(request_id)``.
        All other arguments are forwarded to ``llm_engine.add_request``.
        """
        counter_value = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(counter_value),
            prompt,
            params,
            priority=priority,
            prompt_adapter_request=prompt_adapter_request,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
        )
1410

1411
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy deprecated ``guided_options`` onto ``params`` in place.

        Args:
            params: Sampling parameters to update. Mutated when
                ``guided_options`` is given.
            guided_options: Legacy guided-decoding request. ``None`` leaves
                ``params`` untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1434
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Show a progress bar over the engine's unfinished
                requests. For ``RequestOutput`` results it also shows the
                estimated input/output token throughput.

        Returns:
            Every finished output, sorted by ``int(request_id)``.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # tqdm can report an elapsed time of 0 (e.g. when
                        # the bar is disabled); avoid ZeroDivisionError.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                        else:
                            in_spd = out_spd = 0.0
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if the engine raises mid-run.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))