llm.py 63.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.entrypoints.utils import _validate_truncation_size
29
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
30
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
31
from vllm.logger import init_logger
32
from vllm.lora.request import LoRARequest
33
34
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
35
from vllm.model_executor.layers.quantization import QuantizationMethods
36
37
38
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
39
from vllm.pooling_params import PoolingParams
40
from vllm.prompt_adapter.request import PromptAdapterRequest
41
42
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
43
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
44
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
47
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
48

49
50
# Module-level logger for this file, created via vLLM's init_logger.
logger = init_logger(__name__)

51
52
# Generic result type used by collective_rpc/apply_model; falls back to Any
# when it cannot be inferred (typing_extensions.TypeVar supports `default`).
_R = TypeVar("_R", default=Any)

53
54

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
55
56
57
58
59
60
61
62
63
64
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
65
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
66
67
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
68
69
70
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
71
72
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
73
74
75
76
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
77
78
79
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
80
81
82
83
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
84
        quantization: The method used to quantize the model weights. Currently,
85
            we support "awq", "gptq", and "fp8" (experimental).
86
87
88
89
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
90
91
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
92
93
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
94
95
96
97
98
99
100
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
101
102
103
104
105
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
106
107
108
109
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
110
111
112
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
113
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
114
            When a sequence has context length larger than this, we fall back
115
116
117
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
118
119
120
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
121
        hf_token: The token to use as HTTP bearer authorization for remote files
122
            . If `True`, will use the token generated when running
123
            `huggingface-cli login` (stored in `~/.huggingface`).
124
125
126
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
127
128
129
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
130
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
131
            :ref:`engine-args`)
nunjunj's avatar
nunjunj committed
132

133
134
135
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
136
    """
137

138
    DEPRECATE_LEGACY: ClassVar[bool] = True
139
140
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

141
142
143
144
145
146
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

147
148
149
150
151
152
153
154
155
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

156
157
158
159
160
161
162
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Build the engine used for offline inference.

        Every explicit argument, plus any extra ``kwargs``, is forwarded to
        :class:`~vllm.EngineArgs`, from which the engine is created.
        """
        # Stats logging is off unless the caller explicitly decides otherwise.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type object (instead of a qualified
        # name string) is serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # ints and dicts go through the CLI parser; anything else (e.g. an
        # existing CompilationConfig) is used as-is.
        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # LLMEngine.from_engine_args picks the V0 or V1 engine implementation.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args,
            usage_context=UsageContext.LLM_CLASS,
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled in lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
250

251
252
253
254
255
256
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer the engine uses for ``lora_request``.

        The lookup is delegated to the engine's tokenizer group via
        ``get_lora_tokenizer``.
        """
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
257
258

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that is not already a cached tokenizer is wrapped with
        :func:`get_cached_tokenizer` before being installed.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are created dynamically, so the class name
        # is the only thing we can check. A user-defined tokenizer whose class
        # name starts with 'Cached' will be misjudged as already cached.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if already_cached else
                           get_cached_tokenizer(tokenizer))
268

269
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling params used when the caller supplies none.

        The model config's diff sampling params are fetched on first use and
        cached on the instance. If that mapping is empty, a default
        ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

277
278
279
280
281
282
283
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        # Preferred (non-legacy) signature: prompts are passed positionally,
        # everything else by keyword. `priority` is listed so type checkers
        # accept it; only overload signatures are visible to callers.
        ...

293
    # LEGACY: single (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Deprecated form: one text prompt plus optional token IDs."""

    # LEGACY: multi (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Deprecated form: a list of text prompts plus optional token IDs."""

    # LEGACY: single (token ids + optional prompt)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Deprecated form: one token-ID prompt, optional text prompt."""

    # LEGACY: multi (token ids + optional prompt)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Deprecated form: several token-ID prompts, optional text prompts."""

    # LEGACY: single or multi token ids [pos-only]
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Deprecated form: token IDs passed positionally, no text prompt."""

nunjunj's avatar
nunjunj committed
374
375
376
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: (Legacy, deprecated) Token IDs of the prompts,
                optionally paired with text ``prompts``. Prefer passing
                token IDs through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A plain
                dict may specify at most one guided decoding option and is
                converted to a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the "generate" or
                "transcription" runner, or if ``guided_options_request`` is a
                dict with more than one entry.

        Note:
            Passing token IDs via the ``prompt_token_ids`` keyword is
            considered legacy and may be removed in the future. You should
            instead pass them via the ``prompts`` parameter (for example as
            ``TokensPrompt`` objects).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        # Legacy path: merge `prompts` and `prompt_token_ids` into the
        # unified prompt representation.
        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        # A dict of guided options may carry only a single guided decoding
        # mode; normalize it to a GuidedDecodingRequest.
        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
472

473
    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """
        Run an RPC on every worker and gather the results.

        Args:
            method: Either the name of a worker method to invoke, or a
                callable that is serialized and shipped to every worker.
                A callable must accept an additional ``self`` argument (the
                worker object) besides what is passed in ``args`` and
                ``kwargs``.
            timeout: Maximum number of seconds to wait. A
                :exc:`TimeoutError` is raised when it elapses; ``None``
                waits indefinitely.
            args: Positional arguments for the worker method.
            kwargs: Keyword arguments for the worker method.

        Returns:
            A list holding one result per worker.

        Note:
            Prefer using this API only for control messages, and set up a
            separate data-plane channel for transferring data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
502
503

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call ``func`` on the model held by each worker.

        Returns:
            The per-worker return values of ``func``.
        """
        return self.llm_engine.model_executor.apply_model(func)
510

511
512
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        At each step every live beam is extended by one token, considering
        the top ``2 * beam_width`` candidate tokens. A candidate that emits
        the EOS token (unless ``params.ignore_eos`` is set) is moved to the
        completed set. The remaining candidates are ranked with
        :func:`get_beam_search_score` and pruned to ``beam_width``.

        Args:
            prompts: A list of prompts. Each prompt is either a text prompt
                (``"prompt"`` key) or a token prompt (``"prompt_token_ids"``
                key). It may also carry ``"multi_modal_data"`` and
                ``"mm_processor_kwargs"``, which are propagated to every beam.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            holds up to ``beam_width`` sequences sorted best-first by score.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances into one batch and
            # record each instance's [start, end) slice of it.
            # chain.from_iterable keeps this linear in the total beam count.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    # If `result.outputs[0].logprobs` is None, the sequence
                    # was completed because of the max-model-len or aborted,
                    # so it is not expanded into new beams.
                    # NOTE(review): such beams are also not added to
                    # `instance.completed`, i.e. they are dropped entirely —
                    # confirm this is intended.
                    if result.outputs[0].logprobs is not None:
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens steps compete with the
            # completed ones for the final top-beam_width slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
643
644
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the chat template, tokenized, and
        the resulting token prompts are passed to :meth:`generate`.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. It also
                selects the tokenizer used to render the conversations.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``
              - "auto" (the default) is resolved via
                ``resolve_chat_template_content_format``.

            add_generation_prompt: If True, asks the chat template to append
                a generation prompt after the last message so that the model
                starts a new reply.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions. They are forwarded to the chat
                template and used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. On key collisions these override the explicit
                arguments above (e.g. ``add_generation_prompt``).
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the effective ``add_generation_prompt`` and
                ``continue_final_message`` (after applying
                ``chat_template_kwargs``) are both True.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        # Explicit arguments first; user-supplied chat_template_kwargs win on
        # key collisions.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        # Enforce the documented contract on the *effective* values, before
        # any tokenizer work. Previously only the HF path failed, deep inside
        # the tokenizer, and the Mistral path silently accepted it.
        if (_chat_template_kwargs.get("continue_final_message")
                and _chat_template_kwargs.get("add_generation_prompt")):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer,
                    trust_remote_code=model_config.trust_remote_code,
                    conversation=conversation,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Preferred call form: prompt(s) positional, options by keyword.
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Legacy: one text prompt, optionally with its token IDs.
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Legacy: several text prompts, optionally with their token IDs.
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Legacy: one token-ID list given by keyword; text prompt optional.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Legacy: batched token-ID lists given by keyword; text prompts
        # optional.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Legacy: token IDs passed positionally, with `prompts` and
        # `pooling_params` explicitly None.
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Either a single value or
                a sequence may be given; every provided value is verified
                against the model config.
            prompt_token_ids: (Legacy) Token IDs of the prompts. When given,
                ``prompts`` must be text (or None) and the two are combined
                into prompt objects via ``_convert_v1_inputs``.
            truncate_prompt_tokens: Optional prompt-truncation size. It is
                validated against the model's ``max_model_len`` and forwarded
                to request creation as tokenization kwargs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the ``pooling``
                runner.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the token IDs through the ``prompts`` parameter,
            e.g. as :class:`~vllm.inputs.TokensPrompt` objects.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            # Give a more actionable hint when the model *could* pool but
            # was started with a different runner.
            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge separate text / token-ID arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            truncate_prompt_tokens: Optional prompt-truncation size,
                forwarded unchanged to :meth:`encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task embed``, or if :meth:`encode` rejects the model.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Optional prompt-truncation size,
                forwarded unchanged to :meth:`encode` (same meaning as in
                :meth:`embed`). Defaults to None, i.e. the previous behavior.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task classify``, or if :meth:`encode` rejects the model.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by the cosine similarity of their embeddings.

        Both sides are pooled in a single batched :meth:`encode` call and
        then split apart again. When ``text_1`` holds exactly one entry, its
        embedding is reused against every entry of ``text_2``.
        """
        split_at = len(text_1)
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        left_embeds = pooled[0:split_at]
        right_embeds = pooled[split_at:]

        # 1 -> N: broadcast the lone left-hand embedding.
        if len(left_embeds) == 1:
            left_embeds = left_embeds * len(right_embeds)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=left_embeds,
                                          embed_2=right_embeds)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score ``(text_1[i], text_2[i])`` pairs with a cross-encoder model.

        Each pair is tokenized jointly (``text`` / ``text_pair``) into a
        single prompt, and the pooled output for that prompt is returned as
        its score. When ``text_1`` holds exactly one entry, it is paired with
        every entry of ``text_2``.

        Raises:
            ValueError: If ``tokenizer`` is a Mistral tokenizer, which this
                scoring path does not support.
        """
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(text_1) == 1:
            # 1 -> N: replicate the single left-hand text.
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        # The prompts are tokenized here rather than by the engine. Any
        # settings _validate_truncation_size puts into tokenization_kwargs
        # are therefore applied directly in the tokenizer call below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        For cross-encoder models each pair is tokenized jointly into one
        prompt. Otherwise both sides are embedded and compared by cosine
        similarity. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Only text is scored: token prompts are decoded back to text, and
        multi-modal prompts are rejected.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompt.
            truncate_prompt_tokens: Optional prompt-truncation size,
                validated against the model's ``max_model_len``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, is not running
                with ``--task embed`` or ``--task score``, or if a prompt is
                multi-modal or is neither a string nor a dict with a
                ``prompt`` or ``prompt_token_ids`` key.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Normalize a single prompt to plain text."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Scoring prompts must be strings, or dicts with a "
                    "'prompt' or 'prompt_token_ids' key; got "
                    f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)

    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Ask the engine to reset its prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to the
                engine's ``reset_prefix_cache``.

        Returns:
            The boolean reported by the engine. Presumably this indicates
            whether the reset succeeded; confirm in the engine implementation.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it should not process any requests meanwhile.

        The caller is responsible for making sure no requests are in flight
        for the whole sleep period, i.e. until `wake_up` is called.

        Args:
            level: The sleep level.

                - Level 1 offloads the model weights to CPU memory and
                  discards the kv cache, whose content is forgotten. Use it
                  when the engine will be woken up to run the same model
                  again, and make sure there is enough CPU memory to hold
                  the weights.
                - Level 2 discards both the model weights and the kv cache,
                  forgetting the content of both. Use it when the engine will
                  be woken up to run a different or updated model and the
                  previous weights are not needed. It reduces CPU memory
                  pressure.
        """
        # Drop cached prefixes first, then hand the sleep request to the
        # engine.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake the engine up from sleep mode (see :meth:`sleep`).

        Args:
            tags: Optional list of tags that limits which engine memory
                allocations are restored. Allowed values are "weights" and
                "kv_cache". None restores everything. Before the engine is
                used again, wake_up must have been called with every tag
                (or with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Turn the legacy ``(prompts, prompt_token_ids)`` pair into prompts.

        Each argument may be a single item or a batch; both are normalized
        through ``parse_and_batch_prompt``. When text prompts are given, the
        result is a list of ``TextPrompt``; otherwise it is a list of
        ``TokensPrompt`` built from ``prompt_token_ids``. When both are
        given, their lengths must match, but only the text prompts are used
        to build the result.

        Raises:
            ValueError: If neither argument is provided, or if both are
                provided with different batch sizes.
        """
        # skip_tokenizer_init is now checked in engine
        texts = None
        if prompts is not None:
            texts = [entry["content"]
                     for entry in parse_and_batch_prompt(prompts)]

        token_batches = None
        if prompt_token_ids is not None:
            token_batches = [
                entry["content"]
                for entry in parse_and_batch_prompt(prompt_token_ids)
            ]

        both_given = texts is not None and token_batches is not None
        if both_given and len(texts) != len(token_batches):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        if texts is None and token_batches is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if texts is not None:
            text_prompts: list[PromptType] = [
                TextPrompt(prompt=text) for text in texts
            ]
            return text_prompts

        # Narrowing only: the checks above guarantee token_batches is set.
        assert token_batches is not None
        token_prompts: list[PromptType] = [
            TokensPrompt(prompt_token_ids=ids) for ids in token_batches
        ]
        return token_prompts
1329
1330
1331

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then enqueue one request per prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. ``SamplingParams`` entries
                are mutated in place: deprecated ``guided_options`` are merged
                in and ``output_kind`` is set to ``FINAL_ONLY``.
            lora_request: One LoRA request shared by every prompt, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Shared by every request.
            tokenization_kwargs: Forwarded unchanged to every request.
            guided_options: Deprecated; emits a ``DeprecationWarning`` and is
                merged into each ``SamplingParams``.
            priority: Optional per-request priorities. When non-empty it must
                have one entry per prompt; when ``None`` or empty, every
                request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a sequence whose length differs from the
                number of prompts. Validation happens before any request is
                enqueued.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test as the indexing below, so tuples are
        # validated (and post-processed) exactly like lists.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Check up front so a short list cannot fail with IndexError after
        # some requests were already enqueued.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1380

1381
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request ID.

        The ID is the string form of the next value drawn from
        ``self.request_counter``; all other arguments are forwarded to the
        engine's ``add_request`` unchanged.
        """
        next_id = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(next_id),
            prompt,
            params,
            priority=priority,
            prompt_adapter_request=prompt_adapter_request,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
        )
1400

1401
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Merge deprecated guided-decoding options into ``params`` in place.

        Args:
            params: Sampling parameters to update.
            guided_options: Legacy guided-decoding request; when ``None``,
                ``params`` is left untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1424
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a progress bar. For ``RequestOutput``
                results it also displays estimated input/output token
                throughput.

        Returns:
            Every finished output, sorted by its integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if not isinstance(output, RequestOutput):
                        pbar.update(1)
                        continue

                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)
                    # tqdm can report an elapsed time of 0 (e.g. when the bar
                    # is disabled); skip the speed estimate rather than
                    # dividing by zero.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                    pbar.update(n)
        finally:
            # Close the bar even if an engine step raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))