"vllm/vscode:/vscode.git/clone" did not exist on "437c3ce02615443ab166f4155028c1d81ee27c06"
llm.py 64.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
8
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
9

10
import cloudpickle
11
import torch.nn as nn
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, deprecated
14

15
16
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
17
18
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
19
20
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
21
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
22
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
23
                                         ChatTemplateContentFormatOption,
24
25
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
26
27
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
28
29
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
30
from vllm.entrypoints.utils import _validate_truncation_size
31
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
32
from vllm.inputs.parse import parse_and_batch_prompt
33
from vllm.logger import init_logger
34
from vllm.lora.request import LoRARequest
35
36
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
37
from vllm.model_executor.layers.quantization import QuantizationMethods
38
39
40
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
41
from vllm.pooling_params import PoolingParams
42
from vllm.prompt_adapter.request import PromptAdapterRequest
43
44
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
45
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
46
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
47
from vllm.usage.usage_lib import UsageContext
48
49
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
50

51
52
53
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

54
55
# Module-level logger shared by everything in this file.
logger = init_logger(__name__)

# Result type of a worker-side callable in `collective_rpc` / `apply_model`;
# falls back to `Any` when it cannot be inferred.
_R = TypeVar("_R", default=Any)

58
59

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
60
61
62
63
64
65
66
67
68
69
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
70
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
71
72
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
73
74
75
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
76
77
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
78
79
80
81
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
82
83
84
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
85
86
87
88
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
89
        quantization: The method used to quantize the model weights. Currently,
90
            we support "awq", "gptq", and "fp8" (experimental).
91
92
93
94
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
95
96
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
97
98
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
99
100
101
102
103
104
105
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
106
107
108
109
110
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
111
112
113
114
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
115
116
117
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
118
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
119
            When a sequence has context length larger than this, we fall back
120
121
122
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
123
124
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
125
126
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
127
        hf_token: The token to use as HTTP bearer authorization for remote files
128
            . If `True`, will use the token generated when running
129
            `huggingface-cli login` (stored in `~/.huggingface`).
130
131
132
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
133
134
135
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
136
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
137

138
139
    Note:
        This class is intended to be used for offline inference. For online
140
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
141
    """
142

143
    DEPRECATE_LEGACY: ClassVar[bool] = True
144
145
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

146
147
148
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
149
    [LLM.__init__][].
150
151
    """

152
153
154
155
156
157
158
159
160
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily enable deprecation of the legacy generate/encode API.

        Sets `DEPRECATE_LEGACY` to True for the duration of the `with` block
        and afterwards restores whatever value it had on entry. The old value
        is restored even if the block raises. Previously the flag was
        unconditionally reset to False, which turned the (default-on)
        deprecation off globally after the first use.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

161
162
163
164
165
166
167
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Builds an `EngineArgs` from the arguments (see the class docstring),
        creates the engine with `LLMEngine.from_engine_args`, and initializes
        the request counter. `default_sampling_params` starts as None and is
        filled lazily by `get_default_sampling_params`.

        Extra keyword arguments are forwarded to `EngineArgs`.
        """
        # Stats logging defaults to off for offline use; an explicit
        # `disable_log_stats` passed by the caller takes precedence.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Only init fields can be passed to the CompilationConfig
            # constructor; report the rest instead of dropping them silently.
            accepted = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            ignored = [key for key in compilation_config if key not in accepted]
            if ignored:
                logger.warning(
                    "Ignoring compilation_config keys that are not init "
                    "fields of CompilationConfig: %s", ignored)
            compilation_config_instance = CompilationConfig(**accepted)
        else:
            # Already a CompilationConfig instance; use it as-is.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        self.default_sampling_params: Union[dict[str, Any], None] = None
259

260
261
262
263
264
265
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer that the engine's tokenizer group associates
        with `lora_request`."""
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
266
267

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        A tokenizer that does not look cached (by class-name prefix) is
        wrapped with `get_cached_tokenizer` before being installed.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are created dynamically, so the class name
        # is the only thing to compare against. A user-defined tokenizer class
        # whose name happens to start with 'Cached' will be misjudged.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
277

278
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default `SamplingParams` for this model.

        The parameter diff from `model_config.get_diff_sampling_param()` is
        fetched lazily and stored on `self.default_sampling_params`. An empty
        (falsy) diff yields a plain `SamplingParams()`.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

286
287
288
289
290
291
292
    # The overloads below only describe the accepted call signatures for
    # static type checkers; at runtime every call dispatches to the final,
    # undecorated `generate` implementation. Each overload declares
    # `priority`, which the implementation accepts.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
383
384
385
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token IDs alongside (or
                instead of) `prompts`. Deprecated; pass
                [TokensPrompt][vllm.inputs.TokensPrompt] entries via
                `prompts` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options for the requests,
                if any. A dict may specify at most one option and is converted
                to a `GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is neither "generate" nor
                "transcription", or if `guided_options_request` is a dict
                specifying more than one guided decoding option.

        Note:
            Passing `prompts` or `prompt_token_ids` as keyword arguments is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts positionally via `prompts`, using
            [TokensPrompt][vllm.inputs.TokensPrompt] entries for
            pre-tokenized input.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ("generate", "transcription"):
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge prompts and token ids into prompt objects.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
483

484
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Invoke `method` on every worker and gather what each one returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable receives the worker object as an extra leading
                `self` argument, followed by the values from `args` and
                `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. Raises a
                [`TimeoutError`][] when exceeded; `None` waits forever.
            args: Positional arguments for the worker method.
            kwargs: Keyword arguments for the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer using this API for control messages only; set up a
            separate data-plane channel for bulk data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
513
514

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held inside each worker.

        Returns:
            The per-worker results, as produced by the model executor's
            `apply_model`.
        """
        return self.llm_engine.model_executor.apply_model(func)
521

522
523
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a
                [TextPrompt][vllm.inputs.TextPrompt] whose `"prompt"` string
                is tokenized here, or a
                [TokensPrompt][vllm.inputs.TokensPrompt] carrying
                `"prompt_token_ids"`. Optional `"multi_modal_data"` and
                `"mm_processor_kwargs"` entries are carried along every beam.
            params: The beam search parameters.

        Returns:
            One `BeamSearchOutput` per prompt, in input order. Each holds up
            to `params.beam_width` sequences sorted by beam-search score
            (best first), with `text` decoded from the sequence's `tokens`.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs: dict[str, Any] = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten every instance's beams into one batch. chain() is linear,
            # unlike sum(..., []), which re-copies the list on each addition.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # all_beams[pos[k]:pos[k + 1]] (and the matching slice of the
            # generate() output) belongs to instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still running when max_tokens is reached compete with the
            # beams that already finished on EOS.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
656
657
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered into prompt token IDs with the chat
        template, and the resulting prompts are handed to the [generate][]
        method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: Either a single conversation or a list of
                conversations.

                - A conversation is a list of messages.
                - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value is applied to every prompt; a list must have the same
                length as the prompts and is paired with them one by one.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. It is
                also passed to `get_tokenizer` to pick the tokenizer that
                renders the conversations.
            chat_template: The template used to structure the chat. If not
                provided, the model's default chat template is used.
            chat_template_content_format: How message content is rendered.

                - "string" renders the content as a string.
                  Example: `"Who are you?"`
                - "openai" renders the content as a list of dictionaries,
                  similar to the OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template to the
                end of each rendered conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are used when resolving
                the content format and are forwarded to the chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. They are applied last, so they override the values
                derived from the arguments above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request, attached to every prompt built by this call.
                Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Accept both a single conversation and a batch of conversations.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(list[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # User-supplied chat_template_kwargs take precedence over the
        # explicit arguments.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }
        use_mistral_template = isinstance(tokenizer, MistralTokenizer)

        def render(msgs: list[ChatCompletionMessageParam]) -> TokensPrompt:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if use_mistral_template:
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **template_kwargs,
                )
            else:
                rendered_text = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already contains the special tokens, so
                # the tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered_text,
                                             add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            return engine_prompt

        prompts: list[Union[TokensPrompt, TextPrompt]] = [
            render(msgs) for msgs in conversations
        ]

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

798
799
800
801
802
803
804
    # Preferred call form: `prompts` (positional-only) is a single prompt or
    # a sequence of prompts; all arguments after `pooling_params` are
    # keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

813
    # Legacy call form: one text prompt, optionally with its token IDs passed
    # separately through `prompt_token_ids`.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...
827

828
    # Legacy call form: a list of text prompts, optionally with one list of
    # token IDs per prompt passed through `prompt_token_ids`.
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # Legacy call form: a single token-ID sequence passed as the keyword-only
    # `prompt_token_ids`, with an optional text prompt.
    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # Legacy call form: several token-ID sequences passed as the keyword-only
    # `prompt_token_ids`, with an optional list of text prompts.
    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # Legacy call form: token IDs passed positionally as the third argument,
    # with `prompts` and `pooling_params` explicitly set to None.
    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
889
890
891
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Every explicitly provided
                value (single or per-prompt sequence) is verified against the
                model config before any request is added.
            prompt_token_ids: LEGACY. Token IDs for the prompts, combined
                with `prompts` by `_convert_v1_inputs`. Prefer passing
                token IDs through `prompts` instead.
            truncate_prompt_tokens: Optional prompt-truncation size. It is
                checked against the model's `max_model_len` and turned into
                the tokenization kwargs used for every request.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the 'pooling'
                runner.

        Note:
            Passing token IDs through the `prompt_token_ids` parameter is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `prompts` parameter (for example as a
            `TokensPrompt`).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # LEGACY path: merge the separate text / token-id arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(self.llm_engine.model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(self.llm_engine.model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
982

983
984
985
986
987
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This is a thin wrapper over `encode` that converts every pooled
        output into an `EmbeddingRequestOutput`. Prompts are batched
        automatically with respect to the memory constraint, so pass all of
        them in a single list for the best performance.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt-truncation size,
                forwarded to `encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with `--task embed`.
        """
        configured_task = self.llm_engine.model_config.task
        if configured_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt-truncation size,
                forwarded to `encode` (same meaning as in `embed`). Defaults
                to None, which keeps the previous behavior.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                `--task classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1070
1071
1072
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their pooled embeddings.

        Both sides are embedded in one `encode` call (the `text_1` prompts
        first, then `text_2`). The output list is split back into the two
        groups and compared pairwise with `_cosine_similarity`.

        Args:
            tokenizer: Tokenizer forwarded to `_cosine_similarity`.
            text_1: Query prompts; a single entry is paired with every
                `text_2` entry.
            text_2: Prompts to compare against.
            truncate_prompt_tokens: Forwarded to `encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One `ScoringRequestOutput` per pair.
        """
        num_queries = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        query_embeds: list[PoolingRequestOutput] = embeddings[:num_queries]
        doc_embeds: list[PoolingRequestOutput] = embeddings[num_queries:]

        # 1 -> N case: reuse the single query embedding for every document.
        if len(query_embeds) == 1:
            query_embeds = query_embeds * len(doc_embeds)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=query_embeds,
                                          embed_2=doc_embeds)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(result) for result in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs with a cross-encoder model.

        Each `(text_1[i], text_2[i])` pair is tokenized jointly (as `text` and
        `text_pair`) into one `TokensPrompt`, and all prompts are run through
        the engine with default `PoolingParams`.

        Args:
            tokenizer: Tokenizer used to encode each text pair.
            text_1: Query texts; a single entry is paired with every
                `text_2` entry.
            text_2: Texts to pair with the queries.
            truncate_prompt_tokens: Optional prompt-truncation size. It is
                checked against the model's `max_model_len` and turned into
                tokenizer kwargs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One `ScoringRequestOutput` per pair, in input order.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`, which this
                scoring path does not support.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The caller has already validated the task, so the error should
            # name the real problem (the tokenizer), not the task flag.
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []
        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        # Prompts are already tokenized above, so no tokenization kwargs are
        # passed to the engine here.
        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1155
1156
1157
1158
1159
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.
        For cross-encoder models the input pairs are tokenized jointly into
        prompts; for embedding models both sides are embedded and compared
        by cosine similarity. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt-truncation size,
                forwarded to the scoring implementation.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model configured with
                `--task embed` or `--task score`. Also raised if a prompt is
                multi-modal or is not a string, `TextPrompt` or
                `TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, and this validates user input.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Unsupported prompt for scoring: expected a string, "
                    "TextPrompt or TokensPrompt, got "
                    f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1255

1256
1257
1258
1259
1260
1261
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1262
1263
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Forwarded unchanged to the engine's
                `reset_prefix_cache`.

        Returns:
            The value returned by the engine's `reset_prefix_cache`.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1264

1265
1266
1267
1268
1269
1270
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level.

                - Level 1 offloads the model weights and discards the KV
                  cache, whose content is forgotten. The weights are backed
                  up in CPU memory, so make sure there is enough CPU memory
                  to hold them. Use it to sleep and later wake up the engine
                  to run the same model again.
                - Level 2 discards both the model weights and the KV cache;
                  the content of both is forgotten. Use it to sleep and wake
                  up the engine to run a different or updated model, where
                  the previous weights are not needed. It reduces CPU memory
                  pressure.
        """
        # The KV cache contents are forgotten at every sleep level, so clear
        # the prefix cache before handing control to the engine.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus, as returned by the V1
            engine's ``get_metrics()``.

        Note:
            This method is only available with the V1 LLM engine.
        """
        # Function-local import: the V1 engine module is only imported when
        # this method is actually called.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine), (
            "get_metrics() is only available with the V1 LLM engine")
        return self.llm_engine.get_metrics()

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy ``prompts``/``prompt_token_ids`` into prompt dicts.

        Args:
            prompts: One prompt string or a list of prompt strings.
            prompt_token_ids: One list of token ids or a list of such lists.

        Returns:
            A list with one entry per request: ``TextPrompt`` objects built
            from ``prompts`` when it is given, otherwise ``TokensPrompt``
            objects built from ``prompt_token_ids``.

        Raises:
            ValueError: If neither argument is given, or if both are given
                but describe a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Batch both arguments before comparing them. For an un-batched
        # input, len() counts characters (str) or token ids (list[int])
        # rather than requests, so checking lengths first could reject a
        # valid single prompt or accept mismatched batches.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType]
        if prompts is not None:
            # Text prompts take precedence when both arguments are given.
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        else:
            assert prompt_token_ids is not None
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids
            ]
        return parsed_prompts

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and add one engine request per prompt.

        ``params``, ``lora_request`` and ``priority`` may each be a single
        value shared by every prompt or a sequence with one entry per
        prompt. All length checks run before the first request is added, so
        a mismatch never leaves the engine holding a partial batch.

        Args:
            prompts: A single prompt (``str`` or prompt dict) or a sequence
                of prompts.
            params: Sampling/pooling parameters, shared or per prompt.
                ``SamplingParams`` instances are mutated in place: guided
                options are applied and ``output_kind`` is set to
                ``RequestOutputKind.FINAL_ONLY``.
            use_tqdm: Show a progress bar while adding requests.
            lora_request: LoRA request, shared or per prompt.
            prompt_adapter_request: Passed to every request.
            tokenization_kwargs: Passed to every request.
            guided_options: Deprecated; use
                ``SamplingParams.guided_decoding`` instead.
            priority: Optional per-prompt priorities; when omitted (or
                empty) every request gets priority 0.

        Raises:
            ValueError: If a per-prompt sequence does not have exactly one
                entry per prompt.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Per-prompt arguments are indexed below whenever they are a
        # Sequence, so validate them with the same test. Checking only
        # `list` let a wrong-length tuple through (failing mid-loop after
        # some requests were already added), and made a tuple of
        # SamplingParams skip the setup loop below.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one request to the engine under a freshly allocated id.

        The id is the next value drawn from ``self.request_counter``,
        converted to ``str``. ``_run_engine`` sorts finished outputs with
        ``int(request_id)``, so the id must remain an integer string.

        Args:
            prompt: The prompt for this request.
            params: Sampling or pooling parameters for this request.
            tokenization_kwargs: Passed through to ``llm_engine.add_request``.
            lora_request: Optional LoRA request for this prompt.
            prompt_adapter_request: Optional prompt adapter request.
            priority: Request priority passed to the engine.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Copy deprecated ``guided_options`` onto ``params.guided_decoding``.

        ``params`` is modified in place and also returned.

        Args:
            params: Sampling parameters to update.
            guided_options: Legacy guided-decoding request. When ``None``,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set while
                ``guided_options`` is also given.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern,
            structural_tag=guided_options.structural_tag,
        )
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until all pending requests have finished.

        Args:
            use_tqdm: Show a progress bar, with estimated input/output token
                throughput computed from ``RequestOutput`` results.

        Returns:
            Every finished output, sorted by integer request id.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # tqdm reports elapsed == 0 for a disabled bar (e.g.
                        # TQDM_DISABLE=1); skip the speed estimate instead of
                        # dividing by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if stepping the engine raises.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))