llm.py 64.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
8
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
9

10
import cloudpickle
11
import torch.nn as nn
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, deprecated
14

15
16
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
17
18
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
19
20
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
21
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
22
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
23
                                         ChatTemplateContentFormatOption,
24
25
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
26
27
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
28
29
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
30
from vllm.entrypoints.utils import _validate_truncation_size
31
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
32
from vllm.inputs.parse import parse_and_batch_prompt
33
from vllm.logger import init_logger
34
from vllm.lora.request import LoRARequest
35
36
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
37
from vllm.model_executor.layers.quantization import QuantizationMethods
38
39
40
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
41
from vllm.pooling_params import PoolingParams
42
from vllm.prompt_adapter.request import PromptAdapterRequest
43
44
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
45
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
46
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
47
from vllm.usage.usage_lib import UsageContext
48
49
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
50

51
52
53
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

54
55
logger = init_logger(__name__)

# Result type of callables run on the workers (see `collective_rpc` and
# `apply_model`); falls back to `Any` when it cannot be inferred.
_R = TypeVar("_R", default=Any)

58
59

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Forwarded unchanged to
            [`EngineArgs`][vllm.EngineArgs].
        task: Forwarded unchanged to [`EngineArgs`][vllm.EngineArgs].
        override_pooler_config: Forwarded unchanged to
            [`EngineArgs`][vllm.EngineArgs].
        compilation_config: Either an integer, a dictionary, or an existing
            `CompilationConfig`. If it is an integer, it is used as the level
            of compilation optimization. If it is a dictionary, it can specify
            the full compilation configuration; keys that are not init fields
            of `CompilationConfig` are ignored.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    [LLM.__init__][].
    """

152
153
154
155
156
157
158
159
160
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Context manager that marks the legacy API as deprecated.

        Inside the ``with`` block ``DEPRECATE_LEGACY`` is True (this flag is
        what the ``deprecate_kwargs`` check on ``generate`` consults); on
        exit it is set to False. The reset runs in a ``finally`` clause so an
        exception raised inside the ``with`` body cannot leave the flag
        stuck at True.
        """
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = False

161
162
163
164
165
166
167
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Normalizes the arguments into an `EngineArgs` and creates the
        underlying engine with `LLMEngine.from_engine_args`. See the class
        docstring for the meaning of each argument.
        """
        # Stats logging is disabled unless the caller explicitly sets it.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if hf_overrides is None:
            hf_overrides = {}

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            # An integer selects the compilation optimization level.
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keys that are not init fields of CompilationConfig are dropped
            # rather than passed to its constructor.
            compilation_config_instance = CompilationConfig(
                **{
                    key: value
                    for key, value in compilation_config.items()
                    if is_init_field(CompilationConfig, key)
                })
        else:
            # Already a CompilationConfig instance; use it as-is.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Populated lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

263
264
265
266
267
268
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request``.

        The lookup is delegated to the engine's tokenizer group, which may
        return a LoRA-specific tokenizer when a request is given.
        """
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
269
270

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Replace the tokenizer held by the engine's tokenizer group.

        A tokenizer that does not look cached is wrapped with
        `get_cached_tokenizer` before being installed.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are created dynamically, so the only way to
        # recognize one is by class name. A user-defined tokenizer whose class
        # name starts with "Cached" will therefore be misjudged as cached.
        if tokenizer.__class__.__name__.startswith("Cached"):
            tokenizer_group.tokenizer = tokenizer
        else:
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

281
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        The model config's sampling overrides (as returned by
        ``model_config.get_diff_sampling_param()``) are fetched on first use
        and cached on ``self.default_sampling_params``. When they are
        non-empty, a ``SamplingParams`` is built from them; otherwise a plain
        ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = (
                self.llm_engine.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            return SamplingParams.from_optional(**self.default_sampling_params)
        return SamplingParams()

289
290
291
292
293
294
295
    # Typing overloads for `generate`: the first is the current API; the
    # remaining ones are deprecated legacy signatures that accept
    # `prompt_token_ids` separately from `prompts`.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
386
387
388
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Deprecated. Token IDs passed separately from
                `prompts`; pass [TokensPrompt][vllm.inputs.TokensPrompt]
                objects via `prompts` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A plain
                dict is converted to a `GuidedDecodingRequest` and must contain
                at most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the 'generate' or
                'transcription' runner, or if `guided_options_request` is a
                dict with more than one entry.

        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts positionally as the first argument,
            using `TokensPrompt` for token IDs.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ("generate", "transcription"):
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge the separate token IDs into prompt objects.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
486

487
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
516
517

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        The call is delegated to the engine's model executor.
        """
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
524

525
526
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a `TextPrompt`
                (text under the "prompt" key, tokenized here) or a
                `TokensPrompt` (token IDs under "prompt_token_ids"). Optional
                "multi_modal_data" and "mm_processor_kwargs" entries are
                carried along with every beam of that prompt.
            params: The beam search parameters.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, holding up to
            `params.beam_width` sequences ordered from best to worst score.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Bound before the closures below so they don't depend on a name
        # assigned later in the function.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten the live beams of every instance into one batch.
            # chain.from_iterable is linear, unlike sum(lists, []), which
            # re-copies the accumulated list on every addition.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            # all_beams[start:end] are the beams of the matching instance.
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if not all_beams:
                break

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
659
660
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered into prompt token IDs with a chat
        template (the Mistral template when the tokenizer is a
        `MistralTokenizer`, otherwise the Hugging Face chat template), and
        the resulting prompts are handed to the [generate][] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation, or a list of conversations.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. It is
                also used to select the tokenizer that renders the chat.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and are also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. On key collisions these take precedence over the
                explicit arguments above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # A list whose items are themselves lists is a batch of
        # conversations; anything else is treated as a single conversation.
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit arguments first; user-supplied chat_template_kwargs are
        # merged last so they override any colliding key.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        use_mistral_template = isinstance(tokenizer, MistralTokenizer)
        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if use_mistral_template:
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs
            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

801
802
803
804
805
806
807
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Preferred signature: prompt(s) passed positionally."""

816
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy signature: one text prompt with optional token IDs."""
830

831
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy signature: several text prompts with optional token IDs."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy signature: one token-ID list with an optional prompt."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy signature: several token-ID lists with optional prompts."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy signature: token IDs only, passed positionally."""

nunjunj's avatar
nunjunj committed
892
893
894
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Parameters supplied by
                the caller (a single value or one per prompt) are verified
                against the model config.
            prompt_token_ids: LEGACY. Pre-tokenized prompt(s). When given,
                they are combined with `prompts` through the legacy input
                conversion. Prefer passing `TokensPrompt` objects in
                `prompts` instead.
            truncate_prompt_tokens: Optional token limit. It is checked
                against the model's `max_model_len` and turned into the
                tokenization kwargs used for every request (presumably each
                prompt is truncated to this many tokens).
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the 'pooling'
                runner.

        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts positionally as the first argument
            (use `TokensPrompt` for pre-tokenized inputs).
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # LEGACY path: build prompt objects from the separate
            # (prompts, prompt_token_ids) arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        # Filled in by _validate_truncation_size when truncation is requested.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
985

986
987
988
989
990
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        This is a thin wrapper over [encode][] for models started with
        `--task embed`. Prompts are batched automatically with the memory
        constraint in mind, so for best throughput pass all prompts in a
        single call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Forwarded unchanged to [encode][].
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model's task is not `embed`.
        """
        task = self.llm_engine.model_config.task
        if task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Forwarded unchanged to [encode][], as in
                `embed`. The default of None keeps the previous behavior.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model's task is not `classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1073
1074
1075
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        Both sides are embedded in one `encode` call (all of `text_1`
        followed by all of `text_2`). The outputs are then split back apart.
        A single `text_1` embedding is repeated so that it pairs with every
        `text_2` embedding.
        """
        split_at = len(text_1)
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        left: list[PoolingRequestOutput] = pooled[:split_at]
        right: list[PoolingRequestOutput] = pooled[split_at:]

        # 1 -> N case: broadcast the single left-hand embedding.
        if len(left) == 1:
            left = left * len(right)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=left,
                                    embed_2=right)

        validated = self.engine_class.validate_outputs(scores,
                                                       PoolingRequestOutput)
        return list(map(ScoringRequestOutput.from_base, validated))

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `<text_1, text_2>` pairs with a cross-encoder model.

        Each pair is tokenized jointly (`text` / `text_pair`) into a single
        `TokensPrompt`, including `token_type_ids` when the tokenizer returns
        them. The resulting prompts are run through the pooling engine with
        default `PoolingParams`. A single `text_1` entry is repeated so that
        it pairs with every `text_2` entry.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # NOTE(review): presumably rejected because the pair
            # tokenization below (`tokenizer(text=..., text_pair=...)`) is
            # not available for Mistral tokenizers — confirm.
            # The previous message told users to switch to
            # `--task embed or score`, but score() has already checked that.
            raise ValueError(
                "Score API with a cross-encoder model is not supported "
                "for MistralTokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        # Filled in by _validate_truncation_size when truncation is
        # requested; applied directly to the pair tokenization below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []
        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1158
1159
1160
1161
1162
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.

        For cross-encoder models, each pair is tokenized together into a
        single prompt. For other models, both sides are embedded and the
        pairs are scored by cosine similarity. Prompts are batched
        automatically, considering the memory constraint. For the best
        performance, put all of your texts into a single list and pass it
        to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional token limit, forwarded to the
                scoring path (validated against the model's
                `max_model_len`).
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError:
                - if the model is not initialized with the 'pooling' runner
                  or its task is not `embed`/`score`;
                - if a prompt carries multi-modal data;
                - if a prompt cannot be resolved to text.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            # Normalize every accepted prompt form to plain text; token IDs
            # are decoded back to a string (see the comment above).
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which is stripped under -O
            # and would let malformed prompts reach the tokenizer.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Scoring prompts must be a str, TextPrompt or "
                    f"TokensPrompt, got {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1258

1259
1260
1261
1262
1263
1264
    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1265
1266
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        `device` is passed straight through to the engine. The engine's
        boolean result is returned unchanged (presumably whether the reset
        succeeded — confirm against the engine implementation).
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1267

1268
1269
1270
1271
1272
1273
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it must not process requests while asleep.

        The caller is responsible for making sure that no requests are being
        processed from the moment this is called until `wake_up` is called.
        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level.

                - Level 1 offloads the model weights, backing them up in CPU
                  memory, and discards the kv cache (its content is
                  forgotten). It suits sleeping and waking up the engine to
                  run the same model again. Make sure there is enough CPU
                  memory to hold the model weights.
                - Level 2 discards both the model weights and the kv cache;
                  the content of both is forgotten. It suits sleeping and
                  waking up the engine to run a different model or to update
                  the model, when the previous weights are not needed. It
                  reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1290
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        # Tags are forwarded as-is; validation (if any) happens in the engine.
        self.llm_engine.wake_up(tags)
1303

1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        # Imported lazily so that the V1 engine module is only loaded when
        # metrics are actually requested.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine), (
            "get_metrics() is only available with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1318
1319
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert the legacy ``prompts`` / ``prompt_token_ids`` arguments
        into a list of prompt objects.

        Args:
            prompts: A single text prompt or a list of text prompts.
            prompt_token_ids: A single token-ID prompt (``list[int]``) or a
                list of them.

        Returns:
            One ``TextPrompt`` per text prompt when ``prompts`` is given;
            otherwise one ``TokensPrompt`` per token-ID prompt. When both are
            given, ``prompts`` takes precedence and the token IDs are only
            used for the length check.

        Raises:
            ValueError: If neither argument is given, or if both are given
                with a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Normalize both inputs into batches *before* comparing lengths. For
        # a single prompt, len(str) counts characters while len(list[int])
        # counts tokens, so comparing the raw arguments is meaningless.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType] = []
        if prompts is not None:
            parsed_prompts.extend(TextPrompt(prompt=text) for text in prompts)
        elif prompt_token_ids is not None:
            parsed_prompts.extend(
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids)
        else:
            # Unreachable: the "neither given" case raised above.
            raise AssertionError

        return parsed_prompts
1359
1360
1361

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by all prompts, or a sequence
                with exactly one entry per prompt.
            use_tqdm: Whether to show a progress bar while submitting.
            lora_request: One LoRA request for all prompts, or a sequence
                with one entry per prompt.
            prompt_adapter_request: Forwarded unchanged to every request.
            tokenization_kwargs: Forwarded unchanged to every request.
            guided_options: Deprecated. Folded into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-request priorities. ``None`` or an empty
                list means priority 0 for every request.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                sequence whose length differs from the number of prompts.
                This is checked before any request is submitted.

        Note:
            ``SamplingParams`` objects are modified in place: their
            ``output_kind`` is set to ``RequestOutputKind.FINAL_ONLY``.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Validate every per-request argument up front, using the same
        # ``Sequence`` test as the submission loop below, so that a tuple is
        # handled like a list and a mismatch is reported before anything has
        # been handed to the engine.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        param_items = params if isinstance(params, Sequence) else (params, )
        for sp in param_items:
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1416

1417
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The request ID is the string form of the next value drawn from
        ``self.request_counter``. ``_run_engine`` sorts finished outputs by
        ``int(request_id)``, so IDs must stay integer strings.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1436

1437
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Fold legacy ``guided_options`` into ``params.guided_decoding``.

        ``params`` is modified in place and also returned. When
        ``guided_options`` is None, ``params`` is returned untouched.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        # Map each legacy ``guided_*`` field onto its GuidedDecodingParams
        # counterpart.
        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern,
            structural_tag=guided_options.structural_tag,
        )
        return params

1460
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every pending request has finished.

        Args:
            use_tqdm: Whether to show a progress bar. For ``RequestOutput``
                results the bar also shows estimated input/output token
                throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # tqdm reports an elapsed time of 0 when the bar is
                        # disabled (e.g. TQDM_DISABLE=1), so guard the
                        # division instead of raising ZeroDivisionError.
                        elapsed = pbar.format_dict["elapsed"]
                        in_spd = total_in_toks / elapsed if elapsed else 0.0
                        out_spd = total_out_toks / elapsed if elapsed else 0.0
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if the engine raises mid-run.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were added before them.
        return sorted(outputs, key=lambda x: int(x.request_id))