llm.py 64 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.entrypoints.utils import _validate_truncation_size
29
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
30
from vllm.inputs.parse import parse_and_batch_prompt
31
from vllm.logger import init_logger
32
from vllm.lora.request import LoRARequest
33
34
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
35
from vllm.model_executor.layers.quantization import QuantizationMethods
36
37
38
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
39
from vllm.pooling_params import PoolingParams
40
from vllm.prompt_adapter.request import PromptAdapterRequest
41
42
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
43
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
44
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
47
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
48

49
50
logger = init_logger(__name__)

51
52
_R = TypeVar("_R", default=Any)

53
54

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
55
56
57
58
59
60
61
62
63
64
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
65
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
66
67
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
68
69
70
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
71
72
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
73
74
75
76
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
77
78
79
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
80
81
82
83
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
84
        quantization: The method used to quantize the model weights. Currently,
85
            we support "awq", "gptq", and "fp8" (experimental).
86
87
88
89
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
90
91
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
92
93
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
94
95
96
97
98
99
100
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
101
102
103
104
105
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
106
107
108
109
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
110
111
112
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
113
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
114
            When a sequence has context length larger than this, we fall back
115
116
117
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
118
        disable_custom_all_reduce: See {class}`~vllm.config.ParallelConfig`
119
120
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
121
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
124
125
126
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
127
128
129
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
130
131
        **kwargs: Arguments for {class}`~vllm.EngineArgs`. (See
            {ref}`engine-args`)
nunjunj's avatar
nunjunj committed
132

133
134
135
136
    :::{note}
    This class is intended to be used for offline inference. For online
    serving, use the {class}`~vllm.AsyncLLMEngine` class instead.
    :::
137
    """
138

139
    DEPRECATE_LEGACY: ClassVar[bool] = True
140
141
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

142
143
144
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
145
    {meth}`LLM.__init__`.
146
147
    """

148
149
150
151
152
153
154
155
156
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

157
158
159
160
161
162
163
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Assemble the engine arguments and create the underlying engine."""

        # Stats logging is disabled unless the caller sets the flag explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as an actual class (rather than a qualified
        # string name) is serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Integers and dictionaries go through the CLI parser in string form;
        # any other non-None value is forwarded untouched.
        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.llm_engine = engine
        self.engine_class = type(engine)

        self.request_counter = Counter()
        # Filled in lazily by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None
251

252
253
254
255
256
257
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer that the engine's tokenizer group provides
        for `lora_request` (which may be None)."""
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
258
259

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        The tokenizer is wrapped with `get_cached_tokenizer` unless its class
        name already starts with "Cached".
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()

        # CachedTokenizer is created dynamically, so the class name is the
        # only thing available to compare. A user-defined tokenizer whose
        # class name starts with 'Cached' will be misjudged as already cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        tokenizer_group.tokenizer = (tokenizer if looks_cached else
                                     get_cached_tokenizer(tokenizer))
269

270
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the sampling parameters used when the caller supplies none.

        The overrides come from the model config's `get_diff_sampling_param()`
        and are fetched once, then kept on the instance. When there are no
        overrides, a plain `SamplingParams()` is returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides

        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

278
279
280
281
282
283
284
    # The only `generate` overload that is not marked as legacy: `prompts` is
    # positional-only (see `/`) and every argument after `sampling_params`
    # must be passed by keyword (see `*`). There is no `prompt_token_ids`
    # parameter in this form.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

294
    # Legacy call form for one request: a single text prompt, optionally
    # accompanied by that prompt's token ids as a flat list of ints.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # Legacy call form for a batch: a list of text prompts, optionally
    # accompanied by one list of token ids per prompt.
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # Legacy call form for one request given as token ids: `prompt_token_ids`
    # is a required keyword-only flat list of ints, and the text prompt is
    # optional.
    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # Legacy call form for a batch given as token ids: `prompt_token_ids` is a
    # required keyword-only list of token-id lists, and the list of text
    # prompts is optional.
    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # Legacy call form where token ids are supplied as the third argument:
    # `prompts` and `sampling_params` are both typed as None, and
    # `prompt_token_ids` is either one flat list of ints or a list of such
    # lists.
    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
375
376
377
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token IDs, either for a
                single prompt or for a batch. When given, it is combined with
                `prompts` into the prompt format used by the engine.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options to use for
                generation, if any. A dictionary is converted to a
                `GuidedDecodingRequest` and may hold at most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the "generate"
                or "transcription" runner, if neither `prompts` nor
                `prompt_token_ids` is given, or if `guided_options_request`
                is a dictionary with more than one entry.

        :::{note}
        Passing `prompts` by keyword or using `prompt_token_ids` is
        considered legacy and may be deprecated in the future. You should
        instead pass the prompts as the first positional argument.
        :::
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge the separate text / token-id arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            if prompts is None:
                # Fail here with a clear message instead of forwarding None
                # to request validation.
                raise ValueError(
                    "Either prompts or prompt_token_ids must be provided.")
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
476

477
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Run an RPC call on every worker by delegating to the engine.

        Args:
            method: Either the name of a worker method, or a callable that
                is serialized and sent to all workers to execute.

                A callable must take an extra leading `self` argument (the
                worker object) in addition to whatever is passed through
                `args` and `kwargs`.
            timeout: Maximum time in seconds to wait for execution. Raises a
                {exc}`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        :::{note}
        It is recommended to use this API to only pass control messages,
        and set up data-plane communication to pass data.
        :::
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
507
508

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Apply `func` to the model inside each worker and return the
        per-worker results.
        """
        return self.llm_engine.model_executor.apply_model(func)
515

516
517
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is a dictionary holding
                either `prompt_token_ids` or a `prompt` text (which is
                tokenized here), optionally together with
                `multi_modal_data` and `mm_processor_kwargs`.
            params: The beam search parameters.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, each holding
            at most `beam_width` sequences sorted by descending beam search
            score.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Fetched before the helpers below, which close over it.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances in a single linear
            # pass; `sum(lists, [])` would re-copy the accumulated list for
            # every instance.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(instance.beams
                                              for instance in instances))

            if not all_beams:
                break

            # Each instance owns the half-open slice [start, end) of
            # `all_beams` (and therefore of the generation output).
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                # Keep only the `beam_width` highest-scoring candidates alive.
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the loop ended compete with the
            # completed ones for the final ranking.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
650
651
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a tokenized prompt using the
        tokenizer's chat template and then passed to the {meth}`generate`
        method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
                It is also used to select the tokenizer.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions. They are forwarded both to the
                content-format resolution and to the chat template call.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. They are applied last, so a key given here
                overrides the matching explicit argument (``chat_template``,
                ``add_generation_prompt``, ``continue_final_message``,
                ``tools``).
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            model_config,
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
        )

        # Explicit arguments first; user-supplied kwargs may override them.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path consumes the raw messages and yields
                # token ids directly.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    model_config,
                    tokenizer,
                    conversation=conversation,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

792
793
794
795
796
797
798
    @overload  # Current form: token ids, if any, are carried inside `prompts`
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
883
884
885
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Explicitly supplied
                parameters are verified against the model config.
            prompt_token_ids: LEGACY. Token ids for the prompt(s). When
                given, `prompts` and `prompt_token_ids` are combined by
                `_convert_v1_inputs`.
            truncate_prompt_tokens: Optional truncation setting. It is handed
                to `_validate_truncation_size` together with the model's
                `max_model_len`; the resulting tokenization kwargs are
                forwarded with the requests.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the engine was not initialized with the
                "pooling" runner.

        :::{note}
        Passing `prompt_token_ids` as a separate parameter is considered
        legacy and may be deprecated in the future. You should instead pass
        token ids as part of the `prompts` parameter.
        :::
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy calling convention: merge text and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
977

978
979
980
981
982
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation setting, forwarded
                unchanged to {meth}`encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with task "embed".
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        # Re-wrap the generic pooling outputs in the embedding output type.
        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with task
                "classify".
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        # Re-wrap the generic pooling outputs in the classification type.
        return [ClassificationRequestOutput.from_base(item) for item in items]

1065
1066
1067
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by cosine similarity of their pooled embeddings.

        Both lists are encoded in a single {meth}`encode` call and the
        results are split back into the two sides afterwards.

        Args:
            tokenizer: Tokenizer handed to `_cosine_similarity`.
            text_1: First side of each pair. A single entry is repeated to
                match the length of `text_2`.
            text_2: Second side of each pair.
            truncate_prompt_tokens: Forwarded to {meth}`encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects built from the
            similarity results.
        """
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        # The first len(text_1) outputs belong to text_1, the rest to text_2.
        num_first = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            :num_first]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            num_first:]

        # 1 -> N scoring: pair the single text_1 embedding with every text_2.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by running each pair through the model jointly.

        Every ``(text_1[i], text_2[i])`` pair is tokenized as one
        ``text`` / ``text_pair`` input and submitted as a pooling request.

        Args:
            tokenizer: Tokenizer used to encode each pair. A
                ``MistralTokenizer`` is rejected.
            text_1: First side of each pair. A single entry is repeated to
                match the length of `text_2`.
            text_2: Second side of each pair.
            truncate_prompt_tokens: Optional truncation setting, handed to
                `_validate_truncation_size` to build the tokenizer kwargs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects, one per pair.

        Raises:
            ValueError: If `tokenizer` is a ``MistralTokenizer``.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The pair encoding below relies on the `text`/`text_pair` call
            # form, which this path does not use with MistralTokenizer.
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # 1 -> N scoring: pair the single text_1 with every text_2.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1150
1151
1152
1153
1154
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See {class}`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation setting, forwarded
                to the scoring implementation.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the engine is not using the "pooling" runner, if
                the task is neither "embed" nor "score", or if a prompt
                carries multi-modal data.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt):
            # Reduce any supported prompt form to plain text: token prompts
            # are decoded, text prompts are unwrapped.
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            assert type(prompt) is str
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1250

1251
1252
1253
1254
1255
1256
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1257
1258
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Forward a prefix-cache reset to the engine and return its result.

        Args:
            device: Passed through unchanged to the engine's
                ``reset_prefix_cache``.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1259

1260
1261
1262
1263
1264
1265
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # The prefix cache is reset first, then the engine is put to sleep.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1282
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the {meth}`sleep` method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                ("weights", "kv_cache",). If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1295

1296
1297
    # LEGACY
    def _convert_v1_inputs(
1298
        self,
1299
1300
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
1301
1302
    ):
        # skip_tokenizer_init is now checked in engine
1303

1304
1305
1306
1307
1308
1309
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
1310

1311
        num_requests = None
1312
1313
        if prompts is not None:
            num_requests = len(prompts)
1314
1315
1316
1317
1318
1319
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

1320
            num_requests = len(prompt_token_ids)
1321
1322
1323
1324
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

1325
        parsed_prompts: list[PromptType] = []
1326
        for i in range(num_requests):
1327
            item: PromptType
1328

1329
            if prompts is not None:
1330
1331
1332
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
1333
            else:
1334
                raise AssertionError
1335

1336
            parsed_prompts.append(item)
1337

1338
        return parsed_prompts
1339
1340
1341

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate a batch of prompts and enqueue one request per prompt.

        A single prompt (``str`` or ``dict``) is wrapped into a one-element
        list. ``params`` and ``lora_request`` may each be a single object
        shared by every prompt or a sequence with one entry per prompt.

        Every `SamplingParams` instance is updated in place: guided-decoding
        options are applied through `_add_guided_params` and ``output_kind``
        is forced to ``RequestOutputKind.FINAL_ONLY``.

        Args:
            prompts: One prompt or a sequence of prompts.
            params: One params object, or one per prompt.
            use_tqdm: Whether to show a progress bar while adding requests.
            lora_request: One LoRA request, one per prompt, or None.
            prompt_adapter_request: Forwarded unchanged with every request.
            tokenization_kwargs: Forwarded unchanged with every request.
            guided_options: Deprecated; emits a `DeprecationWarning` when
                given. Use ``SamplingParams.guided_decoding`` instead.
            priority: Optional per-prompt priorities; 0 is used for every
                request when omitted or empty.

        Raises:
            ValueError: If ``params`` or ``lora_request`` is a sequence
                whose length differs from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # `Sequence` (not `list`) is checked everywhere so that validation,
        # the params fix-up loop and the per-request indexing below all agree
        # on what counts as "one entry per prompt" (e.g. tuples).
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1396

1397
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the underlying engine.

        The request ID is the string form of the next value drawn from
        ``self.request_counter``; every other argument is forwarded to
        ``llm_engine.add_request`` unchanged.

        Args:
            prompt: The prompt for this request.
            params: Sampling or pooling parameters for this request.
            tokenization_kwargs: Optional tokenization options.
            lora_request: Optional LoRA request.
            prompt_adapter_request: Optional prompt adapter request.
            priority: Priority value passed to the engine.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1416

1417
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy legacy guided-decoding options onto ``params``.

        Args:
            params: Sampling parameters, modified in place when
                ``guided_options`` is given.
            guided_options: Legacy guided-decoding request; when None,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )

        return params

1440
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Finished outputs are collected as they are produced. When
        ``use_tqdm`` is true a progress bar is shown; for `RequestOutput`
        results it also reports estimated input/output token throughput.

        Args:
            use_tqdm: Whether to display a progress bar.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                for output in self.llm_engine.step():
                    if not output.finished:
                        continue

                    outputs.append(output)
                    if pbar is None:
                        continue

                    if not isinstance(output, RequestOutput):
                        pbar.update(1)
                        continue

                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)

                    # Read the elapsed time once and guard against a zero
                    # reading so the speed estimate can never raise
                    # ZeroDivisionError.
                    elapsed = pbar.format_dict["elapsed"]
                    in_spd = total_in_toks / elapsed if elapsed > 0 else 0.0
                    out_spd = total_out_toks / elapsed if elapsed > 0 else 0.0
                    pbar.postfix = (
                        f"est. speed input: {in_spd:.2f} toks/s, "
                        f"output: {out_spd:.2f} toks/s")
                    pbar.update(n)
        finally:
            # Close the bar even if a step raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))