"docs/vscode:/vscode.git/clone" did not exist on "a3bf8d4a2b3b84c8580e0984e8463937e25b1b99"
llm.py 62.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
29
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
30
from vllm.logger import init_logger
31
from vllm.lora.request import LoRARequest
32
33
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
34
35
36
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
37
from vllm.pooling_params import PoolingParams
38
from vllm.prompt_adapter.request import PromptAdapterRequest
39
40
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
41
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
42
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
43
from vllm.usage.usage_lib import UsageContext
44
45
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
46

47
48
# Module-level logger, named after this module.
logger = init_logger(__name__)

49
50
# Type variable for the per-worker results returned by `collective_rpc` and
# `apply_model`; defaults to `Any` (PEP 696 TypeVar default).
_R = TypeVar("_R", default=Any)

51
52

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence length covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`.
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the level of compilation optimization. If it is a dictionary, it
            can specify the full compilation configuration. An instance is
            used as-is.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine-args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """
135

136
    DEPRECATE_LEGACY: ClassVar[bool] = True
    """
    Class-wide switch: when True, the legacy generate/encode API (such as the
    ``prompt_token_ids`` keyword) is reported as deprecated.
    """

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    Class-wide switch: when True, passing any :meth:`LLM.__init__` argument
    other than ``model`` positionally is reported as deprecated.
    """

145
146
147
148
149
150
151
152
153
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Enable :attr:`DEPRECATE_LEGACY` for the duration of a ``with`` block.

        On exit the flag is set to False (it is not restored to its prior
        value). This happens even when the ``with`` body raises.
        """
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Without ``finally`` an exception raised inside the ``with`` body
            # would skip this reset and leave the flag enabled.
            cls.DEPRECATE_LEGACY = False

154
155
156
157
158
159
160
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Every explicit argument, plus any extra ``**kwargs``, is forwarded to
        :class:`~vllm.EngineArgs`, from which the engine is created.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        """
        # Stats logging is off by default; callers may still pass
        # disable_log_stats explicitly to override it.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Normalize compilation_config: an int (level) or a dict is parsed via
        # CompilationConfig.from_cli on its string form; anything else is
        # assumed to already be a CompilationConfig and is used as-is.
        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: Optional[dict[str, Any]] = None
253

254
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer currently held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group()
        return group.tokenizer
256
257

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that does not look cached yet is wrapped with
        :func:`get_cached_tokenizer` before being installed.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are generated dynamically, so the class
        # name prefix is the only available test. A user-defined tokenizer
        # whose class name starts with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if looks_cached else get_cached_tokenizer(tokenizer))
267

268
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default :class:`SamplingParams` for this model.

        The model config's sampling-parameter overrides are fetched lazily and
        stored on the instance. When there are no overrides (empty or None),
        a plain ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

276
277
278
279
280
281
282
    # Preferred form: prompts (positional-only) plus keyword-only options.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[
            Union[SamplingParams, Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[
            Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: single (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[
            Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: multi (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[
            Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: single (token ids + optional prompt)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[
            Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: multi (token ids + optional prompt)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[
            Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: single or multi token ids [pos-only]
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[
            Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
373
374
375
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing prompt token IDs; pass
                them through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                may specify only a single guided decoding option; it is
                converted into a :class:`GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running a "generate" or
                "transcription" runner, or if ``guided_options_request`` is a
                dict specifying more than one guided decoding option.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            removed in the future. You should instead pass token IDs via the
            ``prompts`` parameter, e.g. as a :class:`~vllm.inputs.TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: merge prompts + token ids into prompt objects.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
471

472
    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """Execute an RPC call on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to all workers to execute. A callable
                receives an extra ``self`` argument (the worker object) in
                addition to ``args`` and ``kwargs``.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. ``None`` means wait
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with one result per worker.

        Note:
            Prefer using this API only for control messages; set up separate
            data-plane communication to move data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
501
502

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call ``func`` on the model held by each worker.

        The call is delegated to the engine's model executor, which returns
        one result per worker.
        """
        return self.llm_engine.model_executor.apply_model(func)
509

510
511
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a
                :class:`~vllm.inputs.TextPrompt` (its ``"prompt"`` string is
                tokenized here) or a :class:`~vllm.inputs.TokensPrompt`
                (its ``"prompt_token_ids"`` are used directly). Optional
                ``"multi_modal_data"`` and ``"mm_processor_kwargs"`` entries
                are carried along with every beam.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order, holding
            up to ``beam_width`` sequences ranked by beam search score. Each
            sequence's ``text`` is the decoding of all of its tokens, prompt
            tokens included.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()
        # Loop-invariant: read once instead of on every candidate token.
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch. This is
            # linear, unlike sum(..., []) which re-copies the list per instance.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # [start, end) slice of all_beams owned by each instance.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens compete with finished ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
644
645
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``
              - Other values (e.g. the default ``"auto"``) are resolved to a
                concrete format by ``resolve_chat_template_content_format``.

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional list of tool definitions. They are forwarded to
                the chat template and used when resolving the content format.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            # Mistral tokenizers render the raw messages themselves and may
            # return token ids; HF tokenizers render the parsed conversation.
            prompt_data: Union[str, list[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    tools=tools,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    trust_remote_code=model_config.trust_remote_code,
                    conversation=conversation,
                    chat_template=chat_template,
                    tools=tools,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                )

            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

782
783
784
785
786
787
788
    # Preferred form: prompts (positional-only) plus optional pooling params.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

796
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...
809

810
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
867
868
869
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A single value or every
                element of a sequence is verified against the model config.
            prompt_token_ids: LEGACY. Token ids to use instead of (or
                alongside) ``prompts``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the ``pooling``
                runner.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. You should instead pass
            :class:`~vllm.inputs.TokensPrompt` objects via the ``prompts``
            parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            # Give a more actionable hint when the model could have been
            # started as a pooling model but was not.
            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # LEGACY path: combine the v1-style arguments into prompt objects.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
953

954
955
956
957
958
959
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with ``--task embed``.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters (same as :meth:`embed`).

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            per-prompt classification results in the same order as the
            input prompts.

        Raises:
            ValueError: If the model was not started with ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1039
1040
1041
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their pooled embeddings.

        All of ``text_1`` and ``text_2`` are encoded in one :meth:`encode`
        call. The outputs are then split back into the two sides. A single
        ``text_1`` entry is broadcast against every ``text_2`` entry.

        Args:
            tokenizer: Tokenizer used for truncation and passed to the
                cosine-similarity helper.
            text_1: Left-hand side inputs (one, or one per ``text_2`` item).
            text_2: Right-hand side inputs.
            truncate_prompt_tokens: If set, every plain-string input is
                tokenized with ``truncation=True`` and this ``max_length``
                before encoding. This mirrors :meth:`_cross_encoding_score`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per scored pair.
        """
        prompts: list[Union[str, TextPrompt, TokensPrompt]] = text_1 + text_2

        if truncate_prompt_tokens is not None:
            # Previously this argument was accepted but silently ignored on
            # the embedding path. Pre-tokenize with truncation so the limit
            # is actually honoured.
            tokenization_kwargs: dict[str, Any] = {
                "truncation": True,
                "max_length": truncate_prompt_tokens,
            }
            prompts = [
                TokensPrompt(prompt_token_ids=tokenizer(
                    p, **tokenization_kwargs)["input_ids"])
                if isinstance(p, str) else p for p in prompts
            ]

        encoded_output: list[PoolingRequestOutput] = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        num_text_1 = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = (
            encoded_output[:num_text_1])
        encoded_output_2: list[PoolingRequestOutput] = (
            encoded_output[num_text_1:])

        # 1 -> N case: reuse the single left embedding for every right item.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score ``(text_1[i], text_2[i])`` pairs with a cross-encoder model.

        Each pair is tokenized jointly (``text`` / ``text_pair``) and submitted
        as a single :class:`TokensPrompt` with default pooling parameters.
        A single ``text_1`` entry is broadcast against every ``text_2`` entry.

        Args:
            tokenizer: Tokenizer used to build the joint pair encodings.
            text_1: Query texts (one, or one per ``text_2`` item).
            text_2: Texts paired with the queries.
            truncate_prompt_tokens: If set, each pair is tokenized with
                ``truncation=True`` and this ``max_length``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per scored pair.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``. Pair
                tokenization is not supported for it here.
        """
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Cross-encoder scoring is not supported with MistralTokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []
        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            # token_type_ids distinguish the two segments for models that
            # use them; .get() keeps this working for tokenizers without them.
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1125
1126
1127
1128
1129
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        For cross-encoder models the input pairs are used to build a list of
        prompts; otherwise the texts are embedded and compared. This method
        automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your texts into a
        single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional maximum number of tokens per
                input. It is forwarded to the underlying scoring helper.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model started with
                ``--task embed`` or ``--task score``, or if a prompt is
                multi-modal or in an unsupported format.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        tokenizer = self.llm_engine.get_tokenizer()

        # The tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs, so every
        # prompt is normalized to a plain string (decoding token ids).
        def ensure_str(prompt: SingletonPrompt) -> str:
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which is stripped under -O
            # and would let unsupported prompts through silently.
            if not isinstance(prompt, str):
                raise ValueError("Unsupported prompt type for scoring: "
                                 f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1225

1226
1227
1228
1229
1230
1231
    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1232
1233
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector forwarded unchanged to the
                engine.

        Returns:
            Whatever boolean the engine's ``reset_prefix_cache`` reports.
        """
        engine = self.llm_engine
        succeeded = engine.reset_prefix_cache(device)
        return succeeded
1234

1235
1236
1237
1238
1239
1240
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Cached prefixes refer to KV-cache content that sleeping discards,
        # so clear them first.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1257
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                ("weights", "kv_cache",). If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1270

1271
1272
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ) -> list[PromptType]:
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into a
        list of prompt objects.

        Both arguments are first batched with ``parse_and_batch_prompt``.
        When text prompts are given they take precedence and produce
        :class:`TextPrompt` items. Otherwise the token ids produce
        :class:`TokensPrompt` items.

        Raises:
            ValueError: If both arguments are None, or if both are given with
                different lengths.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        parsed_prompts: list[PromptType]
        if prompts is not None:
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        else:
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids
            ]

        return parsed_prompts
1314
1315
1316

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by all prompts, or a sequence
                with exactly one entry per prompt.
            lora_request: One LoRA request shared by all prompts, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Prompt adapter passed with every prompt.
            guided_options: Deprecated; folded into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities. When ``None`` or empty,
                every request gets priority 0.

        Note:
            Every ``SamplingParams`` passed in is mutated in place: its
            ``output_kind`` is set to ``RequestOutputKind.FINAL_ONLY``.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a sequence whose length differs from the
                number of prompts (or propagated from ``_add_guided_params``).
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Per-request arguments are indexed below whenever they are a
        # Sequence (list, tuple, ...), so validate them with the same test;
        # checking only `list` let mismatched tuples slip through.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if (isinstance(lora_request, Sequence)
                and len(lora_request) != num_requests):
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # An empty/None priority means "default 0 for all"; anything else
        # must have one entry per prompt, or indexing below would fail.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        all_params = params if isinstance(params, Sequence) else (params, )
        for sp in all_params:
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request ID.

        The ID is the decimal string of the next value drawn from
        ``self.request_counter``. ``_run_engine`` converts IDs back with
        ``int(...)`` when ordering finished outputs.
        """
        rid = next(self.request_counter)
        self.llm_engine.add_request(
            str(rid),
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Fold deprecated ``guided_options`` into ``params.guided_decoding``.

        Returns:
            ``params`` itself. It is mutated in place only when
            ``guided_options`` is given; otherwise it is returned untouched.

        Raises:
            ValueError: if ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern)

        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a tqdm progress bar annotated with
                estimated input/output token throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # tqdm's "elapsed" can still be 0.0 on a coarse
                        # clock when the first request finishes quickly;
                        # skip the estimate rather than divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if an engine step raises, so the terminal
            # is not left with a dangling progress line.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))