llm.py 60.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from contextlib import contextmanager
6
7
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
                    Tuple, Type, Union, cast, overload)
8

9
import cloudpickle
10
import torch
11
import torch.nn as nn
12
from tqdm import tqdm
13
from typing_extensions import TypeVar, deprecated
14

15
from vllm import envs
16
17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
18
from vllm.config import CompilationConfig
19
20
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
21
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
22
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
23
                                         ChatTemplateContentFormatOption,
24
25
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
26
27
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
28
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
29
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
30
from vllm.logger import init_logger
31
from vllm.lora.request import LoRARequest
32
33
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
34
35
36
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
37
from vllm.pooling_params import PoolingParams
38
from vllm.prompt_adapter.request import PromptAdapterRequest
39
40
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
41
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
42
43
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
44
from vllm.usage.usage_lib import UsageContext
45
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
46

47
48
logger = init_logger(__name__)

49
50
_R = TypeVar("_R", default=Any)

51
52

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
53
54
55
56
57
58
59
60
61
62
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None (the default), it is treated as False.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Forwarded unchanged to :class:`~vllm.EngineArgs`
            as ``mm_processor_kwargs``.
        task: The task to initialize the model for (e.g. ``"generate"``),
            forwarded to :class:`~vllm.EngineArgs`. Defaults to ``"auto"``.
        override_pooler_config: Forwarded unchanged to
            :class:`~vllm.EngineArgs` as ``override_pooler_config``.
        compilation_config: Either an integer, a dictionary, or an already
            constructed ``CompilationConfig``. If it is an integer, it is used
            as the level of compilation optimization. If it is a dictionary,
            it can specify the full compilation configuration.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine-args`) Unless ``disable_log_stats`` is given here,
            it defaults to True.

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

141
142
143
144
145
146
147
148
149
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

150
151
152
153
154
155
156
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Builds an :class:`~vllm.EngineArgs` from the named arguments plus any
        extra ``**kwargs`` and instantiates the engine selected by
        :meth:`get_engine_class`. See the class docstring for the meaning of
        each argument.

        Note:
            If ``enforce_eager`` is unset (``None``) it defaults to False.

            ``compilation_config`` may be an int, a dict, or an already
            constructed ``CompilationConfig``. Ints and dicts are converted
            with ``CompilationConfig.from_cli`` applied to their ``str()``
            form; an existing instance is passed through unchanged.
        """
        # Stats logging is off by default for this entry point unless the
        # caller sets `disable_log_stats` explicitly.
        kwargs.setdefault("disable_log_stats", True)

        worker_cls = kwargs.get("worker_cls")
        # if the worker_cls is not qualified string name,
        # we serialize it using cloudpickle to avoid pickling issues
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            # Already a CompilationConfig; use it as-is.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()

Joe Runde's avatar
Joe Runde committed
247
248
249
250
251
252
253
254
    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Return the engine class to instantiate.

        The V1 engine is chosen when ``envs.VLLM_USE_V1`` is set; otherwise
        the default ``LLMEngine`` is used.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine
        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

255
256
257
258
259
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer stored on the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is stored as-is;
        any other tokenizer is wrapped with ``get_cached_tokenizer`` first.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # CachedTokenizer classes are created dynamically, so the class name
        # is the only marker available. A user-defined tokenizer whose class
        # name happens to begin with 'Cached' will be misjudged.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if already_cached else
                           get_cached_tokenizer(tokenizer))
268

269
270
271
272
273
274
275
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling params used when a caller supplies none.

        If ``get_diff_sampling_param()`` on the model config returns a
        non-empty mapping, it is forwarded to ``SamplingParams.from_optional``;
        otherwise a plain ``SamplingParams()`` is returned.
        """
        overrides = self.llm_engine.model_config.get_diff_sampling_param()
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

276
277
278
279
280
281
282
    # Preferred (non-legacy) signature. `priority` is declared here because
    # the implementation accepts it; without it, type checkers would reject
    # `generate(prompts, params, priority=[...])`.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

292
    # Legacy typing stubs for `generate`. Each accepts `prompt_token_ids`
    # separately from `prompts` and is marked @deprecated; new code should
    # pass token ids through `prompts` instead.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Here `prompt_token_ids` is required and keyword-only.
    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Batched variant of the stub above: one token-id list per prompt.
    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Positional call form `generate(None, None, token_ids)`.
    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
373
374
375
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Deprecated. Token IDs supplied separately from
                ``prompts``; pass them through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options applied to the
                requests, if any. A plain dict is converted into a
                ``GuidedDecodingRequest`` and may hold at most one option.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``generate``
                runner, or if ``guided_options_request`` is a dict with more
                than one entry.

        Note:
            Passing ``prompt_token_ids`` separately from ``prompts`` is
            legacy and deprecated. Pass token IDs via the ``prompts``
            parameter instead, e.g. as a :class:`~vllm.inputs.TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: `_convert_v1_inputs` (defined elsewhere in this
            # class) combines `prompts` and `prompt_token_ids`.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
471

472
    def collective_rpc(self,
473
                       method: Union[str, Callable[..., _R]],
474
475
                       timeout: Optional[float] = None,
                       args: Tuple = (),
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
        
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        executor = self.llm_engine.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
503
        """
504
505
        Run a function directly on the model inside each worker,
        returning the result for each of them.
506
        """
507
508
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
509

510
511
    def beam_search(
        self,
        prompts: List[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a
                :class:`~vllm.inputs.TokensPrompt`, whose
                ``prompt_token_ids`` are used directly, or a
                :class:`~vllm.inputs.TextPrompt`, whose ``prompt`` text is
                encoded with the tokenizer.
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in input order. Each holds
            at most ``params.beam_width`` sequences, highest score first.
            Every sequence's ``text`` is the decoding of its full token list,
            prompt tokens included.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Fetched before the sort key is defined, since the key closes over it.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch. Using
            # chain.from_iterable is linear, whereas sum(lists, []) re-copies
            # the accumulated list for every instance (quadratic).
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # instance_start_and_end[k] is the slice of all_beams (and of the
            # generate() output below) that belongs to instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens steps compete with the
            # finished ones for the final top-beam_width slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
620
621
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a prompt using the
        tokenizer's chat template, and the :meth:`generate` method is then
        called to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``
              - "auto" (the default) leaves the choice to
                ``resolve_chat_template_content_format``, which is given the
                chat template and the tokenizer.

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions forwarded to the chat template
                (both the Mistral and the HF template paths receive them).
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        # The content format is resolved once, outside the per-conversation
        # loop, and shared by every conversation.
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            # Mistral templates consume the raw messages; HF templates
            # consume the parsed conversation.
            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # The template may yield token IDs or text; wrap accordingly.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    # Preferred signature: prompts (text or token IDs) are passed via
    # ``prompts``; the remaining overloads cover the legacy
    # ``prompt_token_ids`` keyword and are marked deprecated.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: LEGACY. Token IDs of the prompts; prefer
                passing them through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the engine is not running a pooling model, or if
                neither ``prompts`` nor ``prompt_token_ids`` is given.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. You should instead pass token IDs via
            the ``prompts`` parameter (e.g. as a
            :class:`~vllm.inputs.TokensPrompt`).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: build prompt dicts from the v1-style arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            if prompts is None:
                # Fail clearly instead of with a TypeError from len(None)
                # further down in _validate_and_add_requests.
                raise ValueError("Either prompts or prompt_token_ids must be "
                                 "provided.")
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per input prompt.

        Prompts are batched automatically within the engine's memory limits;
        for the best throughput, hand all prompts over in a single list.

        Args:
            prompts: One prompt or a sequence of prompts. See
                :class:`~vllm.inputs.PromptType` for the accepted formats.
            use_tqdm: Show a tqdm progress bar while running.
            lora_request: Optional LoRA request(s) to apply.
            prompt_adapter_request: Optional prompt adapter request to apply.

        Returns:
            ``EmbeddingRequestOutput`` objects, one per prompt, ordered like
            the input prompts.

        Raises:
            ValueError: If the model was not started with ``--task embed``.
        """
        configured_task = self.llm_engine.model_config.task
        if configured_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled_outputs = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification output for each prompt, in the same order as the
            input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: List[Union[str, TextPrompt, TokensPrompt]],
        text_2: List[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their pooled embeddings.

        All inputs from both sides are embedded in a single :meth:`encode`
        call. If ``text_1`` holds a single entry, it is paired with every
        entry of ``text_2``; otherwise entries are paired positionally.

        Args:
            tokenizer: Used to truncate ``str`` inputs when
                ``truncate_prompt_tokens`` is set, and to look up
                ``pad_token_id`` when building each pair's combined
                ``prompt_token_ids``.
            text_1: Left-hand inputs of the pairs.
            text_2: Right-hand inputs of the pairs.
            truncate_prompt_tokens: If set, every ``str`` input is tokenized
                with ``truncation=True`` and ``max_length`` equal to this
                value (the same kwargs the cross-encoder path uses) and sent
                as a ``TokensPrompt``. Dict prompts are passed through as-is.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair, in pair order.
        """
        if truncate_prompt_tokens is not None:
            # Previously this argument was accepted but silently ignored.
            tokenization_kwargs: Dict[str, Any] = {
                "truncation": True,
                "max_length": truncate_prompt_tokens,
            }

            def _truncate(prompt):
                # Only raw strings are (re)tokenized here; structured prompt
                # dicts are left untouched.
                if not isinstance(prompt, str):
                    return prompt
                token_ids = tokenizer(text=prompt,
                                      **tokenization_kwargs)["input_ids"]
                return TokensPrompt(prompt_token_ids=token_ids)

            text_1 = [_truncate(t) for t in text_1]
            text_2 = [_truncate(t) for t in text_2]

        encoded_output = self.encode(
            text_1 + text_2,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)
        num_text_1 = len(text_1)
        encoded_output_1 = encoded_output[:num_text_1]
        encoded_output_2 = encoded_output[num_text_1:]

        # 1 -> N: reuse the single left-hand embedding for every pair.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scorer = torch.nn.CosineSimilarity(0)
        # Loop-invariant: look the pad token up once.
        pad_token_id = getattr(tokenizer, "pad_token_id", None)

        scores: List[PoolingRequestOutput] = []
        for embed_1, embed_2 in zip(encoded_output_1, encoded_output_2):
            pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)

            # Join the two token sequences, separated by the pad token when
            # the tokenizer defines one.
            if pad_token_id is not None:
                tokens = (embed_1.prompt_token_ids + [pad_token_id] +
                          embed_2.prompt_token_ids)
            else:
                tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids

            scores.append(
                PoolingRequestOutput(
                    request_id=f"{embed_1.request_id}_{embed_2.request_id}",
                    outputs=pair_score,
                    prompt_token_ids=tokens,
                    finished=True))

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: List[str],
        text_2: List[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Score text pairs with a cross-encoder model.

        Each ``(text_1[i], text_2[i])`` pair is tokenized jointly (as
        ``text`` / ``text_pair``) and submitted as one ``TokensPrompt``,
        including ``token_type_ids`` when the tokenizer returns them. A
        single ``text_1`` entry is paired with every ``text_2`` entry.

        Args:
            tokenizer: Tokenizer used to encode each pair.
            text_1: Query texts.
            text_2: Texts to pair with the queries.
            truncate_prompt_tokens: If set, each pair is tokenized with
                ``truncation=True`` and ``max_length`` set to this value.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair, in pair order.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The previous message pointed at `--task`, which is not the
            # cause here; the tokenizer type is.
            raise ValueError(
                "Cross-encoder scoring is not supported with MistralTokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []
        for query, text in zip(text_1, text_2):
            prompt_inputs = tokenizer(text=query,
                                      text_pair=text,
                                      **tokenization_kwargs)
            parsed_prompts.append(
                TokensPrompt(
                    prompt_token_ids=prompt_inputs["input_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        For cross-encoder models the pairs are scored by
        ``_cross_encoding_score``; otherwise by ``_embedding_score``. This
        method automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your texts into a
        single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional maximum number of tokens per
                tokenized input, forwarded to the underlying scoring helper.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model initialized with
                ``--task embed`` or ``--task score``, if a prompt carries
                multi-modal data, or if the input lengths are invalid.
            TypeError: If a prompt cannot be converted to a string.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            if "pooling" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            # Normalize any supported singleton prompt to plain text.
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(prompt, str):
                raise TypeError("Unsupported prompt for scoring: expected a "
                                "string, TextPrompt or TokensPrompt, got "
                                f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: List[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: List[str] = [ensure_str(t) for t in text_2]

        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(input_text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(input_text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)

        return self._embedding_score(
            tokenizer,
            input_text_1,  # type: ignore[arg-type]
            input_text_2,  # type: ignore[arg-type]
            truncate_prompt_tokens,
            use_tqdm,
            lora_request,
            prompt_adapter_request)

    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling.

        Whatever the engine call returns is discarded.
        """
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling.

        Delegates to ``self.llm_engine.stop_profile()``; its return value is
        discarded.
        """
        self.llm_engine.stop_profile()

    def reset_prefix_cache(self) -> bool:
        """Reset the engine's prefix cache.

        Returns:
            The value returned by ``self.llm_engine.reset_prefix_cache()``
            (annotated as ``bool``).
        """
        return self.llm_engine.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset (via :meth:`reset_prefix_cache`) before the
        engine is put to sleep.

        :param level: The sleep level. Level 1 sleep will offload the model
            weights and discard the kv cache. The content of kv cache is
            forgotten. Level 1 sleep is good for sleeping and waking up the
            engine to run the same model again. The model weights are backed
            up in CPU memory. Please make sure there's enough CPU memory to
            store the model weights. Level 2 sleep will discard both the model
            weights and the kv cache. The content of both the model weights
            and kv cache is forgotten. Level 2 sleep is good for sleeping and
            waking up the engine to run a different model or update the model,
            where previous model weights are not needed. It reduces CPU memory
            pressure.
        """
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self):
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details."""
        self.llm_engine.wake_up()

    # LEGACY
    def _convert_v1_inputs(
1248
1249
        self,
        prompts: Optional[Union[str, List[str]]],
1250
1251
1252
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        # skip_tokenizer_init is now checked in engine
1253

1254
1255
1256
1257
1258
1259
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
1260

1261
        num_requests = None
1262
1263
        if prompts is not None:
            num_requests = len(prompts)
1264
1265
1266
1267
1268
1269
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

1270
            num_requests = len(prompt_token_ids)
1271
1272
1273
1274
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

1275
        parsed_prompts: List[PromptType] = []
1276
        for i in range(num_requests):
1277
            item: PromptType
1278

1279
            if prompts is not None:
1280
1281
1282
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
1283
            else:
1284
                raise AssertionError
1285

1286
            parsed_prompts.append(item)
1287

1288
        return parsed_prompts
1289
1290
1291

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit one request per prompt.

        Args:
            prompts: A single prompt (``str`` or dict) or a sequence of them.
            params: Either one params object shared by all prompts, or a
                sequence with exactly one entry per prompt.
            lora_request: Either one LoRA request shared by all prompts, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Shared by all prompts.
            guided_options: Deprecated; forwarded to ``_add_guided_params``
                for every ``SamplingParams``.
            priority: Optional per-prompt priorities; ``0`` when omitted.

        Raises:
            ValueError: If a per-request ``params`` or ``lora_request``
                sequence does not match the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test as the indexing below. Previously
        # only ``list`` was checked, so a tuple of params/LoRA requests
        # skipped length validation and the per-params setup loop.
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if params_is_seq else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if params_is_seq else params,
                lora_request=lora_request[i] if lora_is_seq else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Hand a single prompt to the underlying engine.

        A fresh id is drawn from ``self.request_counter`` and passed, as a
        string, to ``self.llm_engine.add_request``. The prompt, its params,
        the optional LoRA / prompt-adapter requests and the priority are
        passed along with it.
        """
        # Draw the id first; the engine expects string request ids.
        new_id = str(next(self.request_counter))
        engine = self.llm_engine
        engine.add_request(new_id,
                           prompt,
                           params,
                           priority=priority,
                           prompt_adapter_request=prompt_adapter_request,
                           lora_request=lora_request)

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Fold deprecated ``guided_options`` into ``params.guided_decoding``.

        ``params`` is modified in place and also returned.

        Args:
            params: Sampling parameters to update.
            guided_options: Legacy guided-decoding request. If ``None``,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            # The adjacent literals are concatenated, so the first one needs
            # its trailing space.
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until it reports no unfinished requests.

        Args:
            use_tqdm: Whether to show a progress bar. For ``RequestOutput``
                results, the bar's postfix also shows estimated input and
                output token throughput.

        Returns:
            Every finished output, sorted by integer request id.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Read elapsed time once per output. Skip the speed
                        # update while it is still 0 (possible on coarse
                        # clocks), which would otherwise divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if an engine step raises.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID. This is necessary because a request
        # may finish before requests that were submitted ahead of it.
        return sorted(outputs, key=lambda x: int(x.request_id))