"vllm/vscode:/vscode.git/clone" did not exist on "56531b79ccc746bb579a49411f32be31bc307d4b"
llm.py 62.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
29
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
30
from vllm.logger import init_logger
31
from vllm.lora.request import LoRARequest
32
33
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
34
35
36
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
37
from vllm.pooling_params import PoolingParams
38
from vllm.prompt_adapter.request import PromptAdapterRequest
39
40
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
41
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
42
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
43
from vllm.usage.usage_lib import UsageContext
44
45
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
46

47
48
logger = init_logger(__name__)

49
50
# Result type of callables executed on the workers (see `collective_rpc` and
# `apply_model`); falls back to `Any` when it cannot be solved from the call.
_R = TypeVar("_R", default=Any)

51
52

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
53
54
55
56
57
58
59
60
61
62
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
63
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
64
65
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
66
67
68
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
69
70
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
71
72
73
74
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
75
76
77
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
78
79
80
81
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
82
        quantization: The method used to quantize the model weights. Currently,
83
            we support "awq", "gptq", and "fp8" (experimental).
84
85
86
87
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
88
89
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
90
91
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
92
93
94
95
96
97
98
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
99
100
101
102
103
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
104
105
106
107
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
108
109
110
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
111
        max_seq_len_to_capture: Maximum sequence length covered by CUDA
            graphs. When a sequence has context length larger than this, we
            fall back to eager mode. Additionally, for encoder-decoder models,
            if the sequence length of the encoder input is larger than this,
            we fall back to eager mode.
116
117
118
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
119
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
122
123
124
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
125
126
127
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
128
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
129
            :ref:`engine-args`)
nunjunj's avatar
nunjunj committed
130

131
132
133
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
134
    """
135

136
    DEPRECATE_LEGACY: ClassVar[bool] = True
137
138
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

139
140
141
142
143
144
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

145
146
147
148
149
150
151
152
153
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily turn on deprecation of the legacy generate/encode API.

        Inside the ``with`` block, ``DEPRECATE_LEGACY`` is set to True. This is
        the flag consulted by the ``prompt_token_ids`` deprecation check on
        :meth:`generate`. On exit, including exit via an exception, the flag
        is restored to the value it had on entry.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore instead of hard-coding False. The class default is True,
            # so resetting to False would silently disable the deprecation
            # for everything that runs after the block. `finally` also
            # guarantees the reset when the body raises.
            cls.DEPRECATE_LEGACY = previous

154
155
156
157
158
159
160
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Every named argument, together with ``**kwargs``, is forwarded to
        :class:`~vllm.EngineArgs`. The engine is then created through
        :meth:`LLMEngine.from_engine_args`. See the class docstring for the
        meaning of each argument.

        ``compilation_config`` may be an int (optimization level), a dict, or
        an already-built :class:`~vllm.config.CompilationConfig`. Ints and
        dicts are converted with ``CompilationConfig.from_cli``; any other
        value is passed through unchanged.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        """
        # Offline usage: stats logging is off unless the caller asks for it.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            # Convert through the string form accepted by `from_cli`.
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            # Any other value (e.g. a ready-made CompilationConfig) is used
            # as-is.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        # Concrete engine type; used later to validate request outputs.
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Populated lazily by `get_default_sampling_params()`.
        self.default_sampling_params: Union[dict[str, Any], None] = None

254
255
256
257
258
259
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer that the engine uses.

        Args:
            lora_request: Optional LoRA request. It is handed to the engine's
                tokenizer group (``get_lora_tokenizer``), which decides which
                tokenizer applies to that request.

        Returns:
            The tokenizer selected by the engine's tokenizer group.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
260
261

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is installed
        as-is. Any other tokenizer is first wrapped with
        :func:`get_cached_tokenizer`.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are built dynamically, so the class-name
        # prefix is the only available check. A user-defined tokenizer whose
        # class name happens to start with "Cached" will be misjudged and
        # installed without wrapping.
        is_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if is_cached else get_cached_tokenizer(tokenizer))
271

272
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default :class:`SamplingParams` for this model.

        On the first call, the model config's sampling-parameter diff
        (``get_diff_sampling_param()``) is fetched and cached on
        ``self.default_sampling_params``. A non-empty diff is applied through
        ``SamplingParams.from_optional``; otherwise plain ``SamplingParams()``
        is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        diff = self.default_sampling_params
        if not diff:
            return SamplingParams()
        return SamplingParams.from_optional(**diff)

280
281
282
283
284
285
286
    # Typed signatures of `generate`. The first overload is the current
    # calling convention; the rest cover the deprecated `prompt_token_ids`
    # forms. The implementation below accepts the union of all of them.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Mirrors the implementation's `priority` parameter so that type
        # checkers accept `generate(..., priority=[...])`.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
377
378
379
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        The given prompts are batched automatically, taking the memory
        constraint into account. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Deprecated legacy input. Token IDs of the
                prompt(s), optionally paired with text ``prompts``. Pass the
                token IDs through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options for the requests,
                if any. When given as a dict, it must contain at most one
                guided decoding option and is converted to a
                :class:`GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate" or
                "transcription" runner, or if ``guided_options_request`` is a
                dict that specifies more than one guided decoding option.

        Note:
            Passing ``prompt_token_ids`` is the legacy calling convention and
            is deprecated. Pass token IDs through the ``prompts`` parameter
            instead, e.g. as a :class:`~vllm.inputs.TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge text prompts and token IDs into prompt dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            # Only a single guided decoding option may be requested at once.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
475

476
    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """Run an RPC on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and executed on each worker. A callable receives
                the worker object as an extra leading ``self`` argument,
                followed by ``args`` and ``kwargs``.
            timeout: Seconds to wait for completion. ``None`` waits without
                limit; on expiry a :exc:`TimeoutError` is raised.
            args: Positional arguments for the worker method.
            kwargs: Keyword arguments for the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer using this API for control messages only. Set up a
            separate data-plane channel for bulk data transfer.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
505
506

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call ``func`` on the model held by each worker.

        The call is delegated to the engine's model executor. The return
        value holds one result per worker.
        """
        return self.llm_engine.model_executor.apply_model(func)
513

514
515
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompt dicts. Each is either a
                :class:`~vllm.inputs.TextPrompt` (its ``"prompt"`` string is
                tokenized here) or a :class:`~vllm.inputs.TokensPrompt` (its
                ``"prompt_token_ids"`` are used directly). Optional
                ``"multi_modal_data"`` and ``"mm_processor_kwargs"`` entries
                are carried along with every beam.
            params: The beam search parameters.

        Returns:
            One :class:`~vllm.beam_search.BeamSearchOutput` per prompt, in
            input order. Each holds up to ``params.beam_width`` sequences,
            sorted by beam search score (best first), whose ``text`` is set by
            decoding the sequence's ``tokens``.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances into one batch (linear
            # time, unlike `sum(lists, [])`) and record each instance's
            # [start, end) slice of that batch.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            pos = list(
                itertools.accumulate(
                    (len(instance.beams) for instance in instances),
                    initial=0))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if not all_beams:
                break

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after the last step compete with the ones
            # that already finished.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
646
647
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the chat template and tokenized
        into a :class:`~vllm.inputs.TokensPrompt`; the resulting prompts are
        then passed to the :meth:`generate` method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``
              - "auto" (the default) lets
                :func:`resolve_chat_template_content_format` choose.

            add_generation_prompt: If True, appends the generation prompt
                from the chat template after the last message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions. They are forwarded to the chat
                template (as the ``tools`` template kwarg) and used when
                resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Keys given here override the values derived from
                the other arguments.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Normalize a single conversation into a batch of one conversation.
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        # Built-in template kwargs first, so user-supplied
        # `chat_template_kwargs` can override them.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path renders straight to token ids.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer,
                    trust_remote_code=model_config.trust_remote_code,
                    conversation=conversation,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Explicitly passed
                parameters (a single value or each element of a sequence)
                are verified against the model config.
            prompt_token_ids: LEGACY. Token ids of the prompts; prefer
                passing :class:`~vllm.inputs.TokensPrompt` via ``prompts``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the 'pooling'
                runner.

        Note:
            Passing ``prompt_token_ids`` (alone or together with string
            ``prompts``) is considered legacy and may be deprecated in the
            future. You should instead pass token-id prompts via the
            ``prompts`` parameter.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                # The model could pool, it was just started with another task.
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # LEGACY call style: rebuild prompt objects from the old kwargs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(self.llm_engine.model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(self.llm_engine.model_config)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task embed``.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        # Re-wrap the generic pooling outputs in the embedding-specific type.
        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification outputs (class logits) in the same order as the
            input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        # Default pooling parameters are used (encode() receives None).
        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by the cosine similarity of their embeddings.

        All prompts of ``text_1`` followed by all prompts of ``text_2`` are
        embedded in one :meth:`encode` call. The outputs are then split back
        into the two sides and compared pairwise. A single-element ``text_1``
        is paired with every entry of ``text_2``.

        Args:
            tokenizer: Tokenizer used for truncation and handed to the
                similarity helper.
            text_1: Left-hand prompts of the pairs.
            text_2: Right-hand prompts of the pairs.
            truncate_prompt_tokens: If set, string and text prompts are
                tokenized with ``truncation=True`` and
                ``max_length=truncate_prompt_tokens`` (the same kwargs
                :meth:`_cross_encoding_score` uses) before being embedded.
                Prompts that are already token ids are passed through as-is.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.
        """
        prompts: list[Union[str, TextPrompt, TokensPrompt]] = text_1 + text_2

        if truncate_prompt_tokens is not None:
            # `encode()` has no truncation option, so tokenize (and truncate)
            # here and hand the engine token ids directly.
            truncated: list[Union[str, TextPrompt, TokensPrompt]] = []
            for prompt in prompts:
                if isinstance(prompt, str):
                    text = prompt
                elif "prompt" in prompt:
                    text = cast(TextPrompt, prompt)["prompt"]
                else:
                    truncated.append(prompt)
                    continue
                token_ids = tokenizer(
                    text,
                    truncation=True,
                    max_length=truncate_prompt_tokens)["input_ids"]
                truncated.append(TokensPrompt(prompt_token_ids=token_ids))
            prompts = truncated

        encoded_output: list[PoolingRequestOutput] = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        # Outputs come back in input order: text_1's first, then text_2's.
        num_first = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            :num_first]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            num_first:]

        # 1 -> N: reuse the single left-hand embedding for every pair.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs with a cross-encoder model.

        Each ``(text_1[i], text_2[i])`` pair is tokenized jointly (as
        ``text``/``text_pair``) into a single prompt. The pooled output of
        that prompt is returned as the pair's score. A single-element
        ``text_1`` is paired with every entry of ``text_2``.

        Args:
            tokenizer: Tokenizer used to encode each pair.
            text_1: Left-hand texts of the pairs.
            text_2: Right-hand texts of the pairs.
            truncate_prompt_tokens: If set, each pair is tokenized with
                ``truncation=True`` and ``max_length`` set to this value.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``, which this
                path does not support.
        """
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for cross-encoder models "
                "that use MistralTokenizer")

        # 1 -> N: repeat the single query for every candidate text.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated
        ``N`` times to pair with the ``text_2`` sentences.
        Cross-encoder models score each pair as a single prompt; other
        models score a pair by the cosine similarity of the two embeddings.
        This method automatically batches the prompts, considering the
        memory constraint. For the best performance, put all of your texts
        into a single list and pass it to this method.

        Args:
            text_1: A single prompt or a list of prompts. A list must either
                hold one prompt or have the same length as the ``text_2``
                list.
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Maximum number of tokens to keep when
                tokenizing the inputs; forwarded to the scoring helper.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model initialized with
                ``--task embed`` or ``--task score``, if a prompt carries
                multi-modal data, or if a prompt is not a string, text
                prompt or tokens prompt.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Normalize a scoring prompt to plain text."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which `python -O` strips.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Unsupported prompt for scoring: expected a string, "
                    f"TextPrompt or TokensPrompt, got {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)

    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        self.llm_engine.stop_profile()

    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Forwarded unchanged to the engine's
                ``reset_prefix_cache``.

        Returns:
            Whatever boolean the engine's ``reset_prefix_cache`` reports.
        """
        return self.llm_engine.reset_prefix_cache(device)

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Both sleep levels discard the kv cache, so drop prefix-cache
        # entries first.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                ("weights", "kv_cache",). If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ) -> list[PromptType]:
        """Convert the legacy ``prompts``/``prompt_token_ids`` kwargs into a
        list of prompt objects.

        When ``prompts`` is given, every item becomes a ``TextPrompt`` (the
        token ids are then only used for the length check). Otherwise every
        item becomes a ``TokensPrompt``.

        Raises:
            ValueError: If both are given with different lengths, or if
                neither is given.
        """
        # skip_tokenizer_init is now checked in engine

        # Batch single inputs so both sides are lists of per-request items.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None
        if prompts is not None:
            num_requests = len(prompts)
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

            num_requests = len(prompt_token_ids)
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        parsed_prompts: list[PromptType] = []
        for i in range(num_requests):
            item: PromptType

            if prompts is not None:
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
            else:
                # Unreachable: num_requests is only set if one side exists.
                raise AssertionError

            parsed_prompts.append(item)

        return parsed_prompts

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate a batch of prompts and per-request options, then enqueue
        one engine request per prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. Any ``SamplingParams``
                found here are mutated in place: guided options are merged
                in and ``output_kind`` is set to ``FINAL_ONLY``.
            lora_request: One LoRA request shared by every prompt, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Prompt adapter passed to every request.
            guided_options: Deprecated; merged into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities. When non-empty it must
                have one entry per prompt; otherwise every request gets 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                per-prompt sequence whose length differs from the number of
                prompts, or if guided decoding is specified twice.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use one "is this per-request?" test for validation, the
        # SamplingParams setup, and indexing. Previously validation checked
        # ``list`` while indexing checked ``Sequence``, so a tuple skipped
        # the length check and its SamplingParams were never configured.
        params_per_request = isinstance(params, Sequence)
        lora_per_request = isinstance(lora_request, Sequence)

        if params_per_request and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")

        if lora_per_request and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        # Validate up front so a short list cannot raise IndexError after
        # some requests have already been enqueued. An empty list keeps the
        # old meaning of "default priority 0 for everyone".
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if params_per_request else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if params_per_request else params,
                lora_request=(lora_request[i]
                              if lora_per_request else lora_request),
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1367

1368
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one prompt to the engine under a fresh request ID.

        The request ID is the next value drawn from
        ``self.request_counter``, converted to ``str``. All other arguments
        are forwarded unchanged to ``self.llm_engine.add_request``.
        """
        counter_value = next(self.request_counter)
        self.llm_engine.add_request(
            str(counter_value),
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1385

1386
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Merge deprecated ``guided_options`` into ``params`` in place.

        Args:
            params: Sampling parameters to update.
            guided_options: Legacy guided-decoding request; when None,
                ``params`` is left untouched.

        Returns:
            ``params`` itself, with ``guided_decoding`` populated from
            ``guided_options`` when those were provided.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            # Field-by-field translation from the legacy request object.
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )

        return params

1409
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a tqdm progress bar. For
                ``RequestOutput`` results the bar also reports estimated
                input/output token throughput.

        Returns:
            All finished outputs, sorted by ``int(request_id)``.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        # try/finally so the progress bar is closed even if step() raises.
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Elapsed time can still be 0.0 just after the bar
                        # starts (coarse clocks); skip the speed estimate
                        # rather than divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests submitted before them.
        return sorted(outputs, key=lambda x: int(x.request_id))